Based on the pseudocode in the FA2 paper, the project may leave out some details of the online softmax.
Analyzing each step of the inner loop:
S = Q K^T: 2 * d * Bc * Br OPs
The row max is counted as 0 FLOPs. Computing P needs 2 * Br * Bc (one subtraction and one exp per element). Updating l needs (2Br + Br + Br * Bc): 2Br + Br comes from the exp rescaling term in the online softmax, and Br * Bc is the row sum of P.
Updating O needs (2Br + Br * d + 2 * d * Bc * Br) in total. If the kernel implemented the rescaling as an actual diagonal-matrix multiplication, it would instead cost (2Br + Br * d * Br + 2 * d * Bc * Br). However, the diagonal matrix is mostly zeros, so in a real implementation that multiplication reduces to a row-wise scale and the extra cost is eliminated.
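A minimal Python sketch of the per-iteration accounting above, assuming one inner-loop step processes a Br x d query block against a Bc x d key/value block and that the O rescaling is a row-wise scale rather than a dense diagonal matmul. The function names `inner_step_ops` and `inner_step_total` and the step labels are hypothetical, chosen only for illustration; the counts simply mirror the numbers listed above.

```python
def inner_step_ops(d: int, Br: int, Bc: int) -> dict:
    """FLOPs for one FA2 inner-loop step, following the counting in this issue."""
    return {
        # S = Q_i K_j^T: Br*Bc dot products of length d, 2 FLOPs per multiply-add
        "S": 2 * d * Bc * Br,
        # row max of S: counted as 0 FLOPs here
        "rowmax": 0,
        # P = exp(S - m): one subtraction and one exp per element
        "P": 2 * Br * Bc,
        # l update: 2*Br + Br for the exp rescaling term, Br*Bc for rowsum(P)
        "l": 2 * Br + Br + Br * Bc,
        # O update: 2*Br + row-wise rescale (Br*d) + P V (2*d*Bc*Br)
        "O": 2 * Br + Br * d + 2 * d * Bc * Br,
    }


def inner_step_total(d: int, Br: int, Bc: int) -> int:
    """Sum of all per-step counts for one inner-loop iteration."""
    return sum(inner_step_ops(d, Br, Bc).values())


# Illustrative block sizes (assumed, not taken from the project)
d, Br, Bc = 128, 64, 64
for name, ops in inner_step_ops(d, Br, Bc).items():
    print(f"{name:>6}: {ops}")

# The per-step sum matches the bracketed inner-loop expression
assert inner_step_total(d, Br, Bc) == 4 * d * Bc * Br + 5 * Br + 3 * Br * Bc + Br * d
```

Summing the dictionary makes it easy to see that the two matmuls (S and P V) dominate, while the softmax-related entries scale only with Br * Bc or Br.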
After the inner loop finishes, the total inner-loop OPs are: (N / Bc) * (4 * d * Bc * Br + 5Br + 3Br * Bc + Br * d)
Adding the final rescaling of O after the inner loop: Br * d
Multiplying by the N / Br outer-loop iterations gives: 4 * d * N^2 + 5N^2/Bc + 3N^2 + d * N^2/Bc + Nd
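To check the closed form, the sketch below sums the same per-step counts over the full loop nest and compares the result with the formula above. It assumes N is a multiple of both Br and Bc; the function names are hypothetical and exist only for this illustration.

```python
def fa2_forward_ops_loop(N: int, d: int, Br: int, Bc: int) -> int:
    """Sum the per-iteration counts over the FA2 loop nest (one head, one batch element)."""
    assert N % Br == 0 and N % Bc == 0, "sketch assumes Br and Bc divide N"
    per_inner = 4 * d * Bc * Br + 5 * Br + 3 * Br * Bc + Br * d
    total = 0
    for _ in range(N // Br):        # outer loop over query blocks
        for _ in range(N // Bc):    # inner loop over key/value blocks
            total += per_inner
        total += Br * d             # final rescaling of the O block
    return total


def fa2_forward_ops_closed(N: int, d: int, Br: int, Bc: int) -> int:
    """Closed form: 4dN^2 + 5N^2/Bc + 3N^2 + dN^2/Bc + Nd (exact when Bc divides N)."""
    return 4 * d * N**2 + 5 * N**2 // Bc + 3 * N**2 + d * N**2 // Bc + N * d


N, d, Br, Bc = 1024, 128, 64, 64   # illustrative sizes, assumed
assert fa2_forward_ops_loop(N, d, Br, Bc) == fa2_forward_ops_closed(N, d, Br, Bc)
```

Note that Br cancels out of the closed form: only the key/value block size Bc affects the non-matmul overhead terms.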
In the original project, the 4 * d * N^2 term (per head and per batch element) corresponds exactly to the prefill-stage QK and SV operations:
qk_matmul_OPs = seqlen * seqlen * head_size * num_attention_heads * batchsize * 2
sv_matmul_OPs = seqlen * head_size * seqlen * num_attention_heads * batchsize * 2
However, the softmax part does not match. Should a larger OP count be used for the online softmax? Additionally, the inference time seems, in theory, to depend only on the formula OPs / bandwidth. If my analysis is correct, the extra OPs would make FA's inference time larger than that of standard attention, which clearly contradicts what is observed in practice. How should this be reconciled theoretically?
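To put the size of the mismatch in numbers, a small sketch that evaluates the project's two matmul formulas next to the non-matmul terms of the FA2 count derived above. The values N = 4096, d = 128, Bc = 64 and batchsize = num_attention_heads = 1 are illustrative assumptions, and the helper names are hypothetical.

```python
def project_matmul_ops(seqlen: int, head_size: int, num_attention_heads: int, batchsize: int) -> int:
    """The two prefill-stage matmul formulas quoted from the project, summed."""
    qk_matmul_OPs = seqlen * seqlen * head_size * num_attention_heads * batchsize * 2
    sv_matmul_OPs = seqlen * head_size * seqlen * num_attention_heads * batchsize * 2
    return qk_matmul_OPs + sv_matmul_OPs


def fa2_extra_ops(N: int, d: int, Bc: int) -> int:
    """Non-matmul terms of the FA2 count: 5N^2/Bc + 3N^2 + dN^2/Bc + Nd."""
    return 5 * N**2 // Bc + 3 * N**2 + d * N**2 // Bc + N * d


N, d, Bc = 4096, 128, 64                  # assumed sizes for illustration
matmul = project_matmul_ops(N, d, 1, 1)   # one head, one batch element
assert matmul == 4 * d * N**2             # matches the 4 * d * N^2 term
extra = fa2_extra_ops(N, d, Bc)
print(f"extra / matmul = {extra / matmul:.4f}")
```

For these sizes, the extra softmax-related OPs come to roughly 1% of the matmul OPs, since the ratio simplifies to 5/(4 d Bc) + 3/(4 d) + 1/(4 Bc) + 1/(4 N).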