mlp_act computation problem #16

Open
TaoZQY opened this issue Oct 25, 2024 · 5 comments

TaoZQY commented Oct 25, 2024

In the latest code, the amount of computation for mlp_act is:

for name in ["mlp_act"]:
    self._analyze_to_results(
        "prefill",
        name,
        OPs=batchsize * hidden_size * seqlen * 1 * 2,
        load_weight=0,
        load_act=batchsize * hidden_size * seqlen * a_byte * 2,
        store_act=batchsize * hidden_size * seqlen * a_byte,
        load_kv_cache=0,
        store_kv_cache=0,
    )

I have the following questions:

  1. The dimension of the input to the LLaMA activation layer is intermediate_size // tp_size:
     "gate_proj": [hidden_size, intermediate_size // tp_size],
     "up_proj": [hidden_size, intermediate_size // tp_size],
     so shouldn't the activation OPs be computed with intermediate_size // tp_size instead of hidden_size? (A small shape check is sketched after the proposed code below.)

  2. SiLU is calculated by the formula SILU(x) = x ⋅ sigmoid(x), so each element costs more than a single multiply.

Overall, shouldn't the amount of computation for mlp_act be:

for name in ["mlp_act"]:
    self._analyze_to_results(
        "prefill",
        name,
        OPs=batchsize * (intermediate_size // tp_size) * seqlen * 1 * 4,
        load_weight=0,
        load_act=batchsize * (intermediate_size // tp_size) * seqlen * a_byte * 2,
        store_act=batchsize * (intermediate_size // tp_size) * seqlen * a_byte,
        load_kv_cache=0,
        store_kv_cache=0,
    )
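
To illustrate point 1, here is a minimal shape check (my own sketch with toy LLaMA-like sizes, not code from this repo) showing that the tensor the activation actually sees has the intermediate dimension, not hidden_size:

import numpy as np

# Toy sizes for illustration only
hidden_size, intermediate_size, tp_size = 4096, 11008, 2
batchsize, seqlen = 1, 16

x = np.zeros((batchsize, seqlen, hidden_size))
gate_proj = np.zeros((hidden_size, intermediate_size // tp_size))
up_proj = np.zeros((hidden_size, intermediate_size // tp_size))

gate = x @ gate_proj   # (batchsize, seqlen, intermediate_size // tp_size)
up = x @ up_proj       # same shape
assert gate.size == batchsize * seqlen * (intermediate_size // tp_size)  # elements SiLU touches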

TaoZQY commented Oct 25, 2024

I am still a little confused about the lm_head OPs computation.
The lm_head input is (1, hidden_size) and the output is (1, vocab_size),
so the OPs should be args['batchsize'] * hiddensize * vocab_size * 2 (matrix multiplication: one multiply and one add per weight).
But the code is:

def post_process(model_params, args):
    hiddensize = get_hidden_size(model_params)
    vocab_size = get_vocab_size(model_params)
    layers = []
    for stage in ["prefill", "decode"]:
        layers.append({
            'name': 'lm_head',
            'stage': stage,
            'OPs': args['batchsize'] * hiddensize * vocab_size * 1,
            'load_weight': hiddensize * vocab_size * args['w_byte'],
            'load_act': hiddensize * args['a_byte'],
            'store_act': vocab_size * args['a_byte'],
        })
    return layers

Overall, shouldn't the amount of computation for lm_head be:

def post_process(model_params, args):
    hiddensize = get_hidden_size(model_params)
    vocab_size = get_vocab_size(model_params)
    layers = []
    for stage in ["prefill", "decode"]:
        layers.append({
            'name': 'lm_head',
            'stage': stage,
            'OPs': args['batchsize'] * hiddensize * vocab_size * 2,  # matrix multiplication: multiply + add per weight
            'load_weight': hiddensize * vocab_size * args['w_byte'],
            'load_act': hiddensize * args['a_byte'],
            'store_act': vocab_size * args['a_byte'],
        })
    return layers
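
As a sanity check on the factor of 2 (a hypothetical helper with toy sizes, not from the repo): a GEMM of shape (m, k) x (k, n) performs m * k * n multiplies and m * k * n adds, i.e. 2 * m * k * n OPs in total.

import numpy as np

def lm_head_ops(batchsize, hidden_size, vocab_size):
    # one multiply and one add per weight element touched
    return 2 * batchsize * hidden_size * vocab_size

hidden_size, vocab_size = 4096, 32000     # LLaMA-7B-like example values
x = np.zeros((1, hidden_size))            # decode-stage input (1, hidden_size)
w = np.zeros((hidden_size, vocab_size))   # lm_head weight
logits = x @ w                            # (1, vocab_size)
print(lm_head_ops(1, hidden_size, vocab_size))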


TaoZQY commented Oct 26, 2024

The current code computes the memory footprint before the lm_head results are added to total_results:

weight_kv_footprint = total_results["prefill"]["load_weight"] + total_results["prefill"]["store_kv_cache"]
decode_tmp_act = 0
for layer_name, result in self.results["decode"].items():
    decode_tmp_act += result["store_act"] # activation is discarded after one layer
total_results["decode"]["memory_consumption"] = decode_tmp_act + weight_kv_footprint
total_results["decode"]["memory_consumption_tmp_act"] = decode_tmp_act
total_results["decode"]["memory_consumption_weight"] = total_results["prefill"]["load_weight"]
total_results["decode"]["memory_consumption_kv_cache"] = total_results["prefill"]["store_kv_cache"]
prefill_tmp_act = 0
for layer_name, result in self.results["prefill"].items():
    prefill_tmp_act += result["store_act"]
total_results["prefill"]["memory_consumption"] = prefill_tmp_act + weight_kv_footprint
total_results["prefill"]["memory_consumption_tmp_act"] = prefill_tmp_act
total_results["prefill"]["memory_consumption_weight"] = total_results["prefill"]["load_weight"]
total_results["prefill"]["memory_consumption_kv_cache"] = total_results["prefill"]["store_kv_cache"]

# lm_head
name = "lm_head"
args = {"batchsize": batchsize, "a_byte": a_byte, "w_byte": w_byte}
for layer_info in self.config.post_process(self.model_params, args):
    self._analyze_to_results(**layer_info)
    for data_name in ALL_DATA_NAMES:
        total_results[layer_info["stage"]][data_name] += self.results[layer_info["stage"]][layer_info["name"]][
            data_name
        ]
  1. The lm_head is included in both the prefill and decode stages, but the order of this code is not correct. The lm_head results should be added to total_results first, and only then should the prefill and decode memory consumption be derived from total_results.

  2. The decode-stage KV cache memory consumption should include the KV stored during prefill as well as during decode itself (although the latter is small).

Overall, the correct code should be:

# lm_head
name = "lm_head"
args = {"batchsize": batchsize, "a_byte": a_byte, "w_byte": w_byte}
for layer_info in self.config.post_process(self.model_params, args):
    self._analyze_to_results(**layer_info)
    for data_name in ALL_DATA_NAMES:
        total_results[layer_info["stage"]][data_name] += self.results[layer_info["stage"]][layer_info["name"]][
            data_name
        ]

weight_kv_footprint = total_results["prefill"]["load_weight"] + total_results["prefill"]["store_kv_cache"]+ total_results["decode"]["store_kv_cache"]
decode_tmp_act = 0
for layer_name, result in self.results["decode"].items():
    decode_tmp_act += result["store_act"] # activation is discarded after one layer
total_results["decode"]["memory_consumption"] = decode_tmp_act + weight_kv_footprint
total_results["decode"]["memory_consumption_tmp_act"] = decode_tmp_act
total_results["decode"]["memory_consumption_weight"] = total_results["prefill"]["load_weight"]
total_results["decode"]["memory_consumption_kv_cache"] = total_results["prefill"]["store_kv_cache"]+ total_results["decode"]["store_kv_cache"]

prefill_tmp_act = 0
for layer_name, result in self.results["prefill"].items():
    prefill_tmp_act += result["store_act"]
total_results["prefill"]["memory_consumption"] = prefill_tmp_act + weight_kv_footprint
total_results["prefill"]["memory_consumption_tmp_act"] = prefill_tmp_act
total_results["prefill"]["memory_consumption_weight"] = total_results["prefill"]["load_weight"]
total_results["prefill"]["memory_consumption_kv_cache"] = total_results["prefill"]["store_kv_cache"]


TaoZQY commented Oct 27, 2024

block_size_r = min(math.ceil(onchip_buffer / (kv_byte * head_size)), head_size)

In the FlashAttention-2 paper, B_r = ⌈M / (4d)⌉. Is there a problem with dividing by kv_byte * head_size here?

For the decode stage the attention output per head is [1, d], but the code computes
o_numel = 1 * seqlen * batchsize * num_attention_heads * a_byte. Is there a problem?
Should the correct code be o_numel = 1 * head_size * batchsize * num_attention_heads * a_byte? (A small comparison sketch with assumed hardware values follows below.)
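
Comparison sketch (assumed hardware values and hypothetical variable names, not from the repo) of the block size as the code computes it versus the FlashAttention-2 paper's B_r = min(⌈M / (4d)⌉, d), plus the decode-stage output size in question:

import math

onchip_buffer = 192 * 1024   # assume ~192 KB of SRAM per SM; hardware dependent
head_size = 128              # d
kv_byte = 2                  # fp16 KV cache

block_size_r_code = min(math.ceil(onchip_buffer / (kv_byte * head_size)), head_size)
block_size_r_paper = min(math.ceil(onchip_buffer / (4 * head_size)), head_size)
print(block_size_r_code, block_size_r_paper)

# Decode stage: one query position per request, so the output element count
# scales with head_size per head rather than with seqlen.
batchsize, num_attention_heads, a_byte = 1, 32, 2
o_bytes = 1 * head_size * batchsize * num_attention_heads * a_byte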

hahnyuan (Owner) commented

Thank you for your thorough code review and detailed analysis. You've identified several important calculation adjustments that need to be made:

  1. For the MLP activation layer:
  • The calculation should indeed use intermediate_size // tp_size instead of hidden_size.
  • The SiLU activation's computational cost (x ⋅ sigmoid(x)) should be reflected in the OPs count: each element needs approximately 5 operations (4 for the sigmoid: sub, exp, add, div; plus 1 for the final multiplication). A rough sketch of this count follows at the end of this comment.
  2. For the LM head computation:
  • The matrix multiplication OPs should be doubled to account for the full multiply-add computation.
  3. Regarding the memory consumption calculation:
  • The ordering of the lm_head parameter addition and the memory consumption calculation needs to be fixed.
  • The decode-stage KV cache memory consumption should include both the prefill and decode KV stores.
  4. For the FlashAttention-2 block size parameters:
  • After reviewing the FlashAttention code again, I've found it challenging to determine the optimal block size calculation. The block size should ideally be tuned for each specific hardware setup. While theoretical calculations are valuable, I have not yet figured out the appropriate method to do this.
  5. The output tensor dimensions for the decode stage need to be adjusted.

We'll update these calculation methods in the code and make sure they reflect the correct logic and dimensions you've outlined.
I hadn't realized there were so many issues in this code; it's clear that this project needs a thorough revision and update. If you're interested, I'd love to discuss potential next steps.
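
For reference, one possible reading of the "~5 operations per element" estimate above (my sketch, not committed code):

def mlp_act_ops(batchsize, seqlen, intermediate_size, tp_size, ops_per_element=5):
    # ops_per_element: 4 for sigmoid (sub, exp, add, div) + 1 for x * sigmoid(x)
    return batchsize * seqlen * (intermediate_size // tp_size) * ops_per_element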


TaoZQY commented Oct 27, 2024

Dear author, I'm truly grateful that you could notice my question in time. Your professionalism and responsibility have deeply impressed me, and I sincerely hope that I can learn more from you in the future. Next, I plan to modify the code and submit a pull request for you to review. I'm looking forward to having the opportunity to cooperate with you and your team to make progress together. Thank you again!
