Skip to content

Commit

Permalink
fix swa
Browse files Browse the repository at this point in the history
  • Loading branch information
sunyt32 committed Aug 26, 2024
1 parent 378d428 commit f4f7b27
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions YOCO/yoco/models/decoder/sliding_window_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, args):
self.args = args
self.embed_dim = args.dim
self.num_heads = args.n_self_heads // args.model_parallel_size
self.window_size = args.sliding_window - 1 # compatible with flash attention
self.window_size = args.sliding_window

self.head_dim = args.dim // args.n_self_heads

Expand Down Expand Up @@ -60,8 +60,9 @@ def forward(
else:
incremental_state["prev_key"][:bsz, start_pos : start_pos + tgt_len] = k
incremental_state["prev_value"][:bsz, start_pos : start_pos + tgt_len] = v

attn = flash_attn_func(q, k, v, causal=True, window_size=(self.window_size - 1, 0))
else:
key, value = k, v
attn = flash_attn_func(q, key, value, causal=True, window_size=(self.window_size - 1, 0))
attn = attn.reshape(bsz, tgt_len, self.head_dim * self.num_heads)

attn = self.out_proj(attn)
Expand Down

0 comments on commit f4f7b27

Please sign in to comment.