Skip to content

Commit

Permalink
Fix dataset preprocessing bug and make LLaMA compatible with FlashAttention 2
Browse files Browse the repository at this point in the history
  • Loading branch information
lxr-tech committed Jul 24, 2023
1 parent 7a6f38e commit 6a03e40
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
5 changes: 2 additions & 3 deletions collie/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,11 @@ def save_propressed(self, path: str, shard_size: int = 4):
shard_idx = 0
meta = np.empty((0, 2), int)
for i in self.indices:
labels = self[i][1]
data = {
"tokens": labels["labels"]
"tokens": self[i]["input_ids"]
}
data.update(
{key: value for key, value in labels.items() if key != "labels"})
{key: value for key, value in self[i].items() if key != "input_ids"})
bytes_data = json.dumps(data).encode() + "\n".encode()
offset = shard.tell()
length = len(data["tokens"])
Expand Down
9 changes: 7 additions & 2 deletions collie/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from einops import rearrange

try:
from flash_attn.flash_attention import FlashAttention
from flash_attn.modules.mha import SelfAttention as FlashAttention
except ModuleNotFoundError:
FlashAttention = None

Expand Down Expand Up @@ -194,7 +194,12 @@ def _forward(self,
assert FlashAttention is not None, \
"Detected flash_attn is not installed. See https://github.com/HazyResearch/flash-attention"
qkv = torch.stack([query, key, value], dim=2)
output, _ = FlashAttention()(qkv, key_padding_mask=attention_mask, causal=True)
output = FlashAttention()(qkv, key_padding_mask=attention_mask.bool(), causal=True)
""" flash_attn_2 note:
from flash_attn.modules.mha import SelfAttention as FlashAttention
require attention_mask as a bool tensor
replace 'output, _ =' as 'output ='
"""
output = rearrange(output, "b n h d -> b n (h d)")
else:
query, key, value = query.permute(0, 2, 1, 3), key.permute(
Expand Down

0 comments on commit 6a03e40

Please sign in to comment.