-
Notifications
You must be signed in to change notification settings - Fork 25
/
long_context_example.py
51 lines (46 loc) · 1.95 KB
/
long_context_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# LLaMA model with KIVI
import warnings
warnings.filterwarnings("ignore")
import torch
import json
from models.llama_kivi import LlamaForCausalLM_KIVI
from transformers import LlamaConfig, AutoTokenizer
from datasets import load_dataset
# Here we use lmsys/vicuna-7b-v1.5-16k as the base model; it supports
# long-context inference up to 16k tokens.
import os

MODEL_NAME = "lmsys/vicuna-7b-v1.5-16k"

config = LlamaConfig.from_pretrained(MODEL_NAME)
# KIVI quantization settings, attached to the config read by LlamaForCausalLM_KIVI.
config.k_bits = 2  # KIVI currently supports 2/4 K/V bits
config.v_bits = 2
config.group_size = 32
config.residual_length = 32  # corresponds to the number of recent fp16 tokens
config.use_flash = True  # use flash-attention with KIVI for long-context inference

# Cache directory for downloaded weights. It can be overridden with the
# KIVI_CACHE_DIR environment variable instead of editing the script.
CACHE_DIR = os.environ.get("KIVI_CACHE_DIR", "/scratch/cached_model")

model = LlamaForCausalLM_KIVI.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME,
    config=config,
    cache_dir=CACHE_DIR,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
).cuda()

# The tokenizer is loaded as the built-in LLaMA class (tokenizer_type="llama",
# slow sentencepiece implementation), so trust_remote_code is not requested.
# NOTE(review): re-enable it only if the checkpoint starts shipping a custom
# tokenizer class.
enc = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=False,
    tokenizer_type="llama",
)
model.eval()
# Passkey-retrieval evaluation. Each JSONL line holds an "input" context and a
# "target" passkey; the model must complete the passkey after the prompt postfix.
file_name = "passkey_examples.jsonl"
method_name = f"K{config.k_bits}V{config.v_bits} KiVi"
print("=========="*2 + f"**{method_name}**" + "=========="*2)
prompt_postfix = "What is the pass key? The pass key is "
# A context manager guarantees the file handle is closed, even if generation raises.
with open(file_name, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines in the JSONL file instead of crashing json.loads
        example = json.loads(line)
        prompt = example["input"] + prompt_postfix
        input_ids = enc(prompt, return_tensors="pt").input_ids.cuda()
        print("-----------------------------------")
        print("#Tokens of Prompt:", input_ids.shape[1], end=" ")
        print("Passkey target:", example["target"])
        # NOTE(review): max_new_tokens is the *character* length of the target.
        # That assumes about one generated token per passkey character; confirm
        # this holds for this tokenizer.
        tokens = model.generate(input_ids, max_new_tokens=len(example["target"]))
        # Drop the first input_ids.shape[1] ids (the prompt) and decode only the continuation.
        answer = prompt_postfix + enc.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True)
        answer = answer.replace("\n", "\\n")  # keep each answer on a single printed line
        answer = f"{method_name}:\n [ {answer} ]"
        print(answer)
        print("-----------------------------------\n")