From 6a03e404274cf20096df4b43be2627a67b5e03b6 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Mon, 24 Jul 2023 10:29:11 +0800 Subject: [PATCH] dataset debug and make llama compatible with FlashAttn2 --- collie/data/dataset.py | 5 ++--- collie/models/llama/model.py | 9 +++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/collie/data/dataset.py b/collie/data/dataset.py index ec35b578..d48fd94b 100644 --- a/collie/data/dataset.py +++ b/collie/data/dataset.py @@ -198,12 +198,11 @@ def save_propressed(self, path: str, shard_size: int = 4): shard_idx = 0 meta = np.empty((0, 2), int) for i in self.indices: - labels = self[i][1] data = { - "tokens": labels["labels"] + "tokens": self[i]["input_ids"] } data.update( - {key: value for key, value in labels.items() if key != "labels"}) + {key: value for key, value in self[i].items() if key != "input_ids"}) bytes_data = json.dumps(data).encode() + "\n".encode() offset = shard.tell() length = len(data["tokens"]) diff --git a/collie/models/llama/model.py b/collie/models/llama/model.py index 4a0db442..44e9ef86 100644 --- a/collie/models/llama/model.py +++ b/collie/models/llama/model.py @@ -18,7 +18,7 @@ from einops import rearrange try: - from flash_attn.flash_attention import FlashAttention + from flash_attn.modules.mha import SelfAttention as FlashAttention except ModuleNotFoundError: FlashAttention = None @@ -194,7 +194,12 @@ def _forward(self, assert FlashAttention is not None, \ "Detected flash_attn is not installed. See https://github.com/HazyResearch/flash-attention" qkv = torch.stack([query, key, value], dim=2) - output, _ = FlashAttention()(qkv, key_padding_mask=attention_mask, causal=True) + output = FlashAttention()(qkv, key_padding_mask=attention_mask.bool(), causal=True) + """ flash_attn_2 note: + from flash_attn.modules.mha import SelfAttention as FlashAttention + require attention_mask as a bool tensor + replace 'output, _ =' as 'output =' + """ output = rearrange(output, "b n h d -> b n (h d)") else: query, key, value = query.permute(0, 2, 1, 3), key.permute(