Running example.py with llama2-7B-hf only saves 500MB of KV cache memory compared to base transformers? #17

Open
riou-chen opened this issue May 29, 2024 · 2 comments

@riou-chen

I ran example.py with llama2-7B-hf, using an input length of 4096 tokens and an output length of 100 tokens, with config.k_bits = 2 and config.v_bits = 2. The KV cache occupies 5.6GB of memory, which saves only about 500MB compared to base transformers. With k and v bits set to 2, the KV cache should occupy less than 1GB, but it does not. Why? Also, the inference speed did not improve.

@zirui-ray-liu
Collaborator

Thank you for your interest! Below is some of my analysis. I hope it helps:

First, under 4096 context length, for llama-2-7b,

KV cache size = BS (1, I assume) * context_len (4096+100) * 2 (K&V) * num_layers (32) * num_heads (32) * head_dim (128) * num_bytes (2 for fp16) / 1024 ** 3 = 2.05 GB. Did you set a larger batch size to get 5.6 GB of memory?
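The arithmetic above can be reproduced with a short script. This is only a sketch of the same formula; the function name `kv_cache_bytes` and its parameter names are made up for illustration and are not part of the repository.

```python
def kv_cache_bytes(batch_size, seq_len, num_layers, num_heads, head_dim, bytes_per_elem=2):
    """Size in bytes of an uncompressed KV cache (hypothetical helper)."""
    # The factor of 2 accounts for storing both keys and values.
    return batch_size * seq_len * 2 * num_layers * num_heads * head_dim * bytes_per_elem


# Values from the reply above: batch size 1, 4096 input + 100 output tokens, fp16.
size = kv_cache_bytes(batch_size=1, seq_len=4096 + 100, num_layers=32, num_heads=32, head_dim=128)
size_gib = size / 1024 ** 3
print(f"{size_gib:.2f} GB")  # prints "2.05 GB"
```

Since the size scales linearly with `batch_size` and `seq_len`, changing either argument shows quickly whether a reported number is plausible for a given setting.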

Second, example.py is the 5-shot GSM8K example. The input length is determined by the number of tokens in the few-shot examples, so the actual length cannot be easily controlled. Did you double-check the input length fed to the model?

Third, what are your group size and residual length? If the defaults are used, then under this 4096-length setting the compression ratio is around 5x to 6x.
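To see how group size and residual length affect the ratio, here is a rough estimate. It assumes group-wise quantization that stores one fp16 scale and one fp16 zero point per group (32 bits of metadata) and keeps the most recent `residual_len` tokens in fp16. That storage layout, the function name, and the sample values are assumptions for illustration, not the repository's confirmed defaults.

```python
def compression_ratio(seq_len, bits, group_size, residual_len, fp_bits=16, meta_bits=32):
    """Estimate full-precision size divided by quantized size (hypothetical model)."""
    # Tokens beyond the residual window are quantized; the rest stay in fp16.
    quantized_tokens = max(seq_len - residual_len, 0)
    full_tokens = seq_len - quantized_tokens
    # Each quantized element pays for its own bits plus a share of the
    # per-group metadata (assumed: fp16 scale + fp16 zero point).
    bits_per_quantized_elem = bits + meta_bits / group_size
    compressed = quantized_tokens * bits_per_quantized_elem + full_tokens * fp_bits
    return seq_len * fp_bits / compressed


# Hypothetical settings; layers, heads, and head dimension cancel out of the ratio.
for group_size in (32, 64):
    ratio = compression_ratio(seq_len=4096 + 100, bits=2, group_size=group_size, residual_len=32)
    print(f"group_size={group_size}: about {ratio:.2f}x")
```

Under these assumptions the metadata overhead, not the 2-bit payload alone, is what keeps the ratio well below the naive 16 / 2 = 8x.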

Fourth, regarding the speedup. TL;DR: your KV cache is small, and our current implementation has not been optimized for the small KV cache setting. If you want a decent speedup, just enlarge your batch size and sequence length. For more details, please check our reply.

Hope it helps! Let me know more details regarding your experiments.

@zirui-ray-liu
Collaborator

@riou-chen

Thank you for your patience. We just released a new branch, develop, where we extensively optimized the codebase. I will write a new blog post about the detailed optimizations.

Since we rewrote the low-level CUDA kernels, using the new implementation requires you to rebuild the CUDA extension:

git fetch
git checkout develop
git pull
cd quant && pip install -e .

Currently it only supports Llama models. I have tested the new implementation with Llama-7B-hf on LongBench and the accuracy looks good.

Let me know if you have any problem with it!
