
Question about qk_matmul's load_act calculation in prefill stage. #19

Open
zzhbrr opened this issue Dec 11, 2024 · 2 comments


zzhbrr commented Dec 11, 2024

Hi!

I think something is wrong in the calculation of load_act for qk_matmul in the prefill stage.

From my understanding, the load_act of qk_matmul should be calculated as load_act = seqlen * head_size * batchsize * num_attention_heads * a_byte. However, in the code at model_analyzer.py#L359, it is written as load_act = seqlen * head_size * batchsize * num_key_value_heads * a_byte.
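To make the difference concrete, here is a toy calculation with made-up GQA-style numbers (illustrative only, not the actual model_analyzer.py code):

```python
# Hypothetical GQA config (illustrative numbers, not from any real model file)
seqlen, batchsize, head_size = 4096, 1, 128
num_attention_heads, num_key_value_heads = 64, 8
a_byte = 2  # fp16 activations

# The two candidate formulas for qk_matmul's load_act in the prefill stage
load_act_q = seqlen * head_size * batchsize * num_attention_heads * a_byte
load_act_kv = seqlen * head_size * batchsize * num_key_value_heads * a_byte

print(load_act_q)   # 67108864 bytes (64 MiB), using num_attention_heads
print(load_act_kv)  # 8388608 bytes (8 MiB), using num_key_value_heads
```

For any GQA model the two formulas differ by a factor of num_attention_heads / num_key_value_heads (8x in this example).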

Could it be that I’m misunderstanding some fundamental concepts, or is there a potential issue with the code?

Thanks!


yyjpro commented Dec 16, 2024

I think load_act = seqlen * head_size * batchsize * num_key_value_heads * a_byte is correct. If num_key_value_heads == num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads == 1, it uses Multi-Query Attention (MQA); otherwise Grouped-Query Attention (GQA) is used. num_key_value_heads is the general case.
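For illustration, a minimal sketch of that classification (my own example, not code from this repo):

```python
def attention_variant(num_attention_heads: int, num_key_value_heads: int) -> str:
    """Classify the attention scheme from the two head counts (illustrative)."""
    if num_key_value_heads == num_attention_heads:
        return "MHA"  # every query head has its own K/V head
    if num_key_value_heads == 1:
        return "MQA"  # all query heads share one K/V head
    return "GQA"      # query heads are grouped over fewer K/V heads

print(attention_variant(32, 32))  # MHA
print(attention_variant(32, 1))   # MQA
print(attention_variant(32, 8))   # GQA
```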


zzhbrr commented Dec 17, 2024

Hi @yyjpro.

In qk_matmul, we need to load the Q matrix and the K matrix. Based on my understanding, the shape of the Q matrix is [batchsize, num_attention_heads, seqlen, head_size], and the shape of the K matrix is [batchsize, num_key_value_heads, seqlen, head_size]. Therefore, load_act = seqlen * head_size * batchsize * num_attention_heads * a_byte and load_kv_cache = seqlen * head_size * batchsize * num_kv_heads * kv_byte, just like the formulas for the decode stage at model_analyzer.py#L264.
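Deriving both loads directly from the tensor shapes makes this explicit. A minimal sketch (hypothetical helper, not the repo's actual code):

```python
import math

def qk_matmul_loads(batchsize, seqlen, head_size,
                    num_attention_heads, num_kv_heads,
                    a_byte=2, kv_byte=2):
    """Bytes read by qk_matmul, computed from the Q and K tensor shapes."""
    q_shape = (batchsize, num_attention_heads, seqlen, head_size)
    k_shape = (batchsize, num_kv_heads, seqlen, head_size)
    load_act = math.prod(q_shape) * a_byte        # Q is an activation
    load_kv_cache = math.prod(k_shape) * kv_byte  # K is read from the KV cache
    return load_act, load_kv_cache

# Example: 32 query heads sharing 8 KV heads (GQA)
print(qk_matmul_loads(1, 2048, 128, 32, 8))  # (16777216, 4194304)
```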

@zzhbrr zzhbrr closed this as completed Dec 17, 2024
@zzhbrr zzhbrr reopened this Dec 17, 2024