Skip to content

Commit

Permalink
yoco init
Browse files Browse the repository at this point in the history
  • Loading branch information
donglixp committed May 9, 2024
1 parent 50c5700 commit 7402b0e
Show file tree
Hide file tree
Showing 42 changed files with 4,513 additions and 4 deletions.
170 changes: 166 additions & 4 deletions YOCO/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,168 @@
# YOCO
# You Only Cache Once: Decoder-Decoder Architectures for Large Language Models

- May 2024: Code release
- May 2024: release preprint [YOCO](https://arxiv.org/abs/)
## Approach
<div align="center">
<img src="./imgs/arch.png" width=60%/>
</div>

## Getting Started
<div align="center">
<img src="./imgs/inference.png" width=50%/>
</div>

## Performance
### Harness Eval
Training with 1T Tokens:
| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
| OpenLLaMA-3B-v2 | 0.339 | 0.676 | 0.657 | **0.700** | 0.260 | 0.767 | 0.629 | 0.924 | 0.619 |
| StableLM-base-alpha-3B-v2 | 0.324 | 0.673 | 0.646 | 0.686 | 0.264 | 0.760 | 0.621 | 0.921 | 0.612 |
| StableLM-3B-4E1T | --- | 0.666 | --- | --- | --- | **0.768**| 0.632 | 0.914 | --- |
| YOCO-3B | **0.379** | **0.731** | 0.645 | 0.689 | **0.298**| 0.763 | 0.639 | 0.924 | **0.634**|

Training with 1.6T Tokens:
| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
| StableLM-3B-4E1T | --- | 0.688 | --- | --- | --- | 0.762 | 0.627 | 0.913 | --- |
| YOCO-3B | 0.396 | 0.733 | **0.644** | 0.698 | 0.300 | 0.764 | 0.631 | 0.921 | 0.636 |
| YOCO-3B-1M | **0.413** | **0.747** | 0.638 | **0.705** | 0.300 | **0.773**| **0.651** | **0.932**| **0.645**|
### Needle In A Haystack
<div align="center">
<img src="./imgs/1m_retrieval.png"/>
</div>

### Multi-Needle Eval
| **Model** | **Size** | **N=1** | **N=2** | **N=4** | **N=8** |
|-------------------------|----------|---------|---------|---------|---------|
| GPT-4-128K | -- | 1.00 | 1.00 | 0.98 | 1.00 |
| MiniCPM-128K | 2.4B | 1.00 | 1.00 | 0.54 | 0.56 |
| ChatGLM3-128K | 6B | 0.94 | 0.72 | 0.52 | 0.44 |
| YaRN-Mistral-128K | 7B | 0.02 | 0.12 | 0.08 | 0.20 |
| LWM-1M-text | 7B | 1.00 | 0.90 | 0.76 | 0.62 |
| YOCO-3B-1M | 3B | 0.98 | 0.98 | 0.84 | 0.56 |

## Setup

To install the required packages, use the following command:

```bash
pip install -r requirements.txt
```

Besides normal packages, [Apex](https://github.com/NVIDIA/apex) and [Flash-Attention](https://github.com/Dao-AILab/flash-attention) should be installed seperately following their offcial guidences.

## Harness Eval

To evaluate models in Harness-Eval, the script is as follows in ```scripts/eval_task.sh```:
```bash
cd fairseq/
TASK='harness_boolq'

torchrun --master-port=29505 --nproc_per_node=1 validate.py \
--data-dir ../harness_data/ \
--criterion harness_eval \
--task harness_eval \
--batch-size 4 \
--eval-data ${TASK} \
--log-format simple --log-interval 10 \
--bf16 \
--tokenizer-pad-to-multiple 8 \
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 4096
```

## Needle In A Haystack Evaluation
Our model uses city-number pairs for long sequence evaluation. To get the results at a certain maximal length, the script is as follows in ```scripts/eval_needle.sh```:
```bash
cd fairseq/
torchrun --master-port=29504 --nproc_per_node=1 validate.py \
--task pseudo \
--criterion needle_haystack \
--batch-size 1 \
--max-epoch 1 \
--no-save \
--tiktoken-model cl100k_base \
--bf16 \
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576
```

To run Multi-Needle experiments, replace ```--criterion needle_haystack``` with ```--criterion multi_needle --needle-num {num}```.

## Pretraining From Scratch
To support distributed training, our implementation is based on infinibatch to read data iteratively. The overall data directory should be organized as follows:
```
Data/
├── json/
│ ├── train.json
│ └── CC.json
│ └── StarCoder.json
│ └── ...
├── shard/
│ ├── CC/
│ │ ├── 00000.jsonl
│ │ ├── 00001.jsonl
│ │ └── ...
│ └── StarCoder/
│ ├── 00000.jsonl
│ ├── 00001.jsonl
│ └── ...
```

We recommend that each sharded data files contains no more than 10K lines with one json dict per line, and jsonl file, such as ```Data/shard/CC/00000.jsonl```, should be in the format like this:
```json
{"text": "File 1 is here..."}
{"text": "File 2 is here..."}
...
```

Then, for each source, a JSON file preserves all the paths of the jsonl files. Take ```Data/json/CC.json``` for example:
```json
[
"/path_to_data/Data/shard/CC/00000.jsonl",
"/path_to_data/Data/shard/CC/00001.jsonl",
...
]
```

Finally, ```train.json``` records all sources' information and sampling ratio:
```json
[
{
"name": "CC",
"weight": 0.5
},
{
"name": "StarCoder",
"weight": 0.2
},
...
]
```

```scripts/train.sh```:
```bash
cd fairseq/
torchrun --nproc-per-node=1 train.py /path_to_data \
--save-interval-updates 5000 \
--no-epoch-checkpoints \
--arch yoco_base \
--criterion cross_entropy \
--task gpt \
--tokens-per-sample 2048 \
--tokenizer-pad-to-multiple 8 \
--pad-to-max-len \
--optimizer adam --adam-betas "(0.9, 0.95)" \
--adam-eps 1e-06 \
--clip-norm 2.0 \
--lr 0.00015 \
--lr-scheduler polynomial_decay \
--warmup-updates 50 \
--weight-decay 0.05 \
--batch-size 1 \
--model-parallel-size 1 \
--update-freq 1 \
--batch-read-ahead 1000 \
--total-num-update 300000 \
--log-format simple --log-interval 10 --disable-validation \
--tiktoken-model cl100k_base \
--save-interval-updates 5000 \
--bf16 # bf16 is encouraged in pre-training
```
Binary file added YOCO/imgs/1m_retrieval.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added YOCO/imgs/arch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added YOCO/imgs/inference.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 12 additions & 0 deletions YOCO/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
torch>=2.2.0
triton>=2.2.0
numpy==1.23.0
fairscale
tiktoken
sentencepiece
ninja
boto3
iopath
git+https://github.com/sunyt32/fairseq.git@moe3#egg=fairseq
git+https://github.com/shumingma/infinibatch.git#egg=infinibatch
git+https://github.com/microsoft/torchscale.git#egg=torchscale
11 changes: 11 additions & 0 deletions YOCO/scripts/eval_needle.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
cd yoco/
torchrun --master-port=29504 --nproc_per_node=1 validate.py \
--task pseudo \
--criterion multi_needle --needle-num 4 \
--batch-size 1 \
--max-epoch 1 \
--no-save \
--tiktoken-model cl100k_base \
--bf16 \
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /data/yutao/ckpt_opensource/YOCO-3B-1M/checkpoint.pth --yoco-model /data/yutao/ckpt_opensource/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576

17 changes: 17 additions & 0 deletions YOCO/scripts/eval_task.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
TASK='harness_boolq'
# TASK='hendrycksTest-abstract_algebra'

cd yoco/
torchrun --master-port=29505 --nproc_per_node=1 validate.py \
--data-dir ../harness_data/ \
--criterion harness_eval \
--task harness_eval \
--batch-size 4 \
--eval-data ${TASK} \
--log-format simple --log-interval 10 \
--bf16 \
--tokenizer-pad-to-multiple 8 \
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /data/yutao/ckpt_opensource/YOCO-3B-1M/checkpoint.pth --yoco-model /data/yutao/ckpt_opensource/YOCO-3B-1M --tokens-per-sample 4096
# --arch llama_from_ckpt --llama-model /data/yutao/llama/llama-2-7b --load-ckpt /data/yutao/llama/llama-2-7b/consolidated.00.pth --tokens-per-sample 4096


27 changes: 27 additions & 0 deletions YOCO/scripts/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
cd yoco/
torchrun --master-port=29501 --nproc-per-node=1 train.py /mnt/nlcredstone/shaohanh/data/redstone_v4_21_config \
--save-interval-updates 5000 \
--no-epoch-checkpoints \
--arch yoco_base \
--criterion cross_entropy \
--task gpt \
--tokens-per-sample 2048 \
--tokenizer-pad-to-multiple 8 \
--pad-to-max-len \
--optimizer adam --adam-betas "(0.9, 0.95)" \
--adam-eps 1e-06 \
--clip-norm 2.0 \
--lr 0.00015 \
--lr-scheduler polynomial_decay \
--warmup-updates 50 \
--weight-decay 0.05 \
--batch-size 1 \
--model-parallel-size 1 \
--update-freq 1 \
--batch-read-ahead 1000 \
--total-num-update 300000 \
--log-format simple --log-interval 10 --disable-validation \
--tiktoken-model cl100k_base \
--no-save \
--bf16 \

2 changes: 2 additions & 0 deletions YOCO/yoco/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
8 changes: 8 additions & 0 deletions YOCO/yoco/criterions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import importlib
import os

# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
file_name = file[: file.find(".py")]
importlib.import_module("criterions." + file_name)
86 changes: 86 additions & 0 deletions YOCO/yoco/criterions/harness_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import torch
import torch.nn.functional as F

from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@register_criterion("harness_eval", dataclass=FairseqDataclass)
class HarnessEvalCriterion(FairseqCriterion):
def __init__(self, cfg, task):
super().__init__(task)

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
model.eval()
net_output, _ = model(sample["net_input"]["src_tokens"])
net_output = net_output[:, :-1, :]
targets = sample["net_input"]["src_tokens"][:, 1:]
loss_mask = sample["net_input"]["gpt_loss_mask"][:, 1:]
label_length = sample["net_input"]["label_length"]
loss = F.cross_entropy(
net_output.float().reshape(-1, net_output.size(-1)),
targets.reshape(-1),
reduction="none",
ignore_index=self.padding_idx,
).reshape(targets.size(0), -1)
loss = loss * loss_mask.int()
loss_norm = loss.sum(-1) / label_length.float()
loss = loss.sum(-1)

option_num = self.task.harness_task.class_num
labels = sample["targets"].view(-1)

assert sample["targets"].size(0) % option_num == 0
sample_size = sample["ntokens"]

pred_label = torch.argmin(loss.view(-1, option_num), dim=1)
pred_norm_label = torch.argmin(loss_norm.view(-1, option_num), dim=1)
target_label = labels.view(-1, option_num)[:, 0]

logging_output = {}

logging_output.update(
{
"loss": 0,
"nsentences": pred_label.size(0),
"sample_size": pred_label.size(0),
"ncorrect": (pred_label == target_label).sum().item(),
"ncorrect_norm": (pred_norm_label == target_label).sum().item(),
}
)

return loss, sample_size, logging_output

@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss = sum(log.get("loss", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
ncorrect_norm = sum(log.get("ncorrect_norm", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss / nsentences, nsentences, round=3
)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=2
)
metrics.log_scalar(
"accuracy_norm", 100.0 * ncorrect_norm / nsentences, nsentences, round=2
)

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
Loading

0 comments on commit 7402b0e

Please sign in to comment.