Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Completion only fine-tuning of instruction models with collections of HF datasets #1103

Open
wants to merge 45 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
59e937c
Add input_masked loss calculation and batching w/ padding
chimezie Jun 7, 2024
8c1d33d
Merge branch 'ml-explore:main' into completion_only
chimezie Jun 16, 2024
0a3ec90
Merge branch 'ml-explore:main' into completion_only
chimezie Jun 24, 2024
1929f53
Merge branch 'ml-explore:main' into completion_only
chimezie Jun 28, 2024
9df7bbb
Generalize HF datasets to a collection of HF dataasets via `datasets`…
chimezie Nov 4, 2024
1f6c370
Updates to LoRA documentation
chimezie Nov 4, 2024
c721220
Fixes to config format in documentattion
chimezie Nov 4, 2024
04cf93d
Fixes to references to hf_datasets
chimezie Nov 4, 2024
e477060
Fix keyword argument invokation
chimezie Nov 4, 2024
24f40c3
Fix iteration over HF dataset collection
chimezie Nov 4, 2024
78b24a2
Fix index calculation
chimezie Nov 4, 2024
95fb224
Merge branch 'ml-explore:main' into completion_only
chimezie Nov 4, 2024
a1fbc52
Merge branch 'ml-explore:main' into completion_only
chimezie Nov 5, 2024
b7b3332
Replace iterate_input_masked_batches with iterate_delineated_batches,…
chimezie Nov 5, 2024
603dab5
Merge branch 'ml-explore:main' into completion_only
chimezie Nov 5, 2024
5579b48
Minor documentation update
chimezie Nov 5, 2024
e0d66f5
Merge remote-tracking branch 'origin/completion_only' into completion…
chimezie Nov 5, 2024
4b88c33
Updates CL lora tuner with input masking that uses default_loss (and …
chimezie Nov 6, 2024
3c76a25
Fix variable reference
chimezie Nov 6, 2024
e45ce38
Add ability to fetch raw prompt and completion text from completion d…
chimezie Nov 6, 2024
90e2da8
Minor fix
chimezie Nov 6, 2024
960ed79
Update sublist search and calculation of input id length
chimezie Nov 6, 2024
bfa6c29
Fix
chimezie Nov 7, 2024
7f89ace
Merge branch 'ml-explore:main' into completion_only
chimezie Nov 8, 2024
3080102
Merge branch 'ml-explore:main' into completion_only
chimezie Nov 9, 2024
01e330d
Add input masking for fine-tuning in documentation
chimezie Nov 10, 2024
791727f
Merge remote-tracking branch 'origin/completion_only' into completion…
chimezie Nov 10, 2024
0a42079
Merge branch 'refs/heads/fix_bos_dupe' into completion_only_fix_bos_dupe
chimezie Nov 10, 2024
cb73b95
Don't dupe BOS
chimezie Nov 10, 2024
4ddbb98
Update documentation
chimezie Nov 10, 2024
8cd0586
Merge branch 'completion_only' into completion_only_fix_bos_dupe
chimezie Nov 10, 2024
c5f37ac
Merge branch 'ml-explore:main' into completion_only_fix_bos_dupe
chimezie Nov 16, 2024
d89dce1
Merge branch 'ml-explore:main' into completion_only_fix_bos_dupe
chimezie Nov 21, 2024
7076c8f
Merge branch 'ml-explore:main' into completion_only_fix_bos_dupe
chimezie Nov 28, 2024
b308733
Merge branch 'ml-explore:main' into completion_only_fix_bos_dupe
chimezie Dec 3, 2024
2c41f15
Default for hf_datasets configuration
chimezie Dec 6, 2024
5d57e80
Merge remote-tracking branch 'origin/completion_only_fix_bos_dupe' in…
chimezie Dec 6, 2024
c65b69f
Merge branch 'ml-explore:main' into completion_only_fix_bos_dupe
chimezie Dec 6, 2024
6b0bbfd
Synch use of special tokens with iterate_batches
chimezie Dec 6, 2024
4349397
Merge remote-tracking branch 'origin/completion_only_fix_bos_dupe' in…
chimezie Dec 6, 2024
9a39f3b
Add response template (or token) argument
chimezie Dec 8, 2024
1ed63e9
Incorporate use of response template for completion masking
chimezie Dec 8, 2024
1981b13
Move response template to LoRA configuration
chimezie Dec 8, 2024
55339e7
Generalize the get_item method to all CompletionDatasets
chimezie Dec 9, 2024
85723f4
Merge branch 'ml-explore:main' into completion_only_fix_bos_dupe
chimezie Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions llms/mlx_lm/LORA.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,27 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.

### Input Masking
There are custom functions for masking the sequence of tokens associated with the `prompt` in a completion dataset
during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
with masked input sequences, use the `--mask-inputs` argument.

This functionality expects a ```response_template``` parameter in the configuration that is either a string representing
a [string that indicate the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
or its corresopnding tokens. This is used to create the mask that excludes the tokens associated from the rest of
the sequence from loss calculations. For example (ChatML):

```yaml
response_template: "<|im_start|>assistant"
```

or (for the corresponding tokens of Gemma's response template)

```yaml
response_template: [106, 2516]
```


### Evaluate

To compute test set perplexity use:
Expand Down Expand Up @@ -267,7 +288,7 @@ it on the command line. For example, pass `--data mlx-community/wikisql` to
train on the pre-formatted WikiwSQL data.

Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
expects. Use a YAML config to specify the Hugging Face (HF) dataset arguments. For
example:

```
Expand All @@ -279,11 +300,29 @@ hf_dataset:

- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
dataset. Use `chat_feature` to specify the key for a chat dataset.

- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.

You can specify a list of HF datasets using the `hf_datasets` (plural) configuration, which is a list of records
each with the same structure as above. For example:

```yaml
hf_datasets:
- hf_dataset:
name: "Open-Orca/OpenOrca"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
prompt_feature: "question"
completion_feature: "response"
- hf_dataset:
name: "trl-lib/ultrafeedback_binarized"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
chat_feature: "chosen"
```

- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

Expand Down
42 changes: 41 additions & 1 deletion llms/mlx_lm/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import yaml

from .tokenizer_utils import TokenizerWrapper
from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
Expand Down Expand Up @@ -58,7 +58,9 @@
"test_batches": 500,
"max_seq_length": 2048,
"lr_schedule": None,
"hf_datasets": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"response_template": None,
}


Expand Down Expand Up @@ -91,6 +93,15 @@ def build_parser():
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)

parser.add_argument(
"--mask-inputs",
dest="mask_inputs",
action="store_true",
help="Whether to mask the inputs when training. Default is False.",
default=False,
)

parser.add_argument(
"--num-layers",
type=int,
Expand Down Expand Up @@ -169,6 +180,13 @@ def train_model(
valid_set,
training_callback: TrainingCallback = None,
):
from .tuner.trainer import (
default_loss,
input_masked_loss,
iterate_batches,
iterate_completion_batches,
)

model.freeze()
if args.fine_tune_type == "full":
for l in model.layers[-min(args.num_layers, 0) :]:
Expand Down Expand Up @@ -197,6 +215,17 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

if isinstance(args.response_template, str):
response_generation_tokens = tokenizer.encode(
args.response_template, add_special_tokens=False
)
else:
if not all([item.isinstance(int) for item in args.response_template]):
raise ValueError(
"Response template must be a list of integers if it is not a string."
)
response_generation_tokens = args.response_template

# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
Expand All @@ -208,6 +237,9 @@ def train_model(
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
response_generation_tokens=no_bos_or_eos(
response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
),
)

model.train()
Expand All @@ -216,6 +248,10 @@ def train_model(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)

if args.mask_inputs:
print("Masking inputs..")

# Train model
train(
model=model,
Expand All @@ -225,6 +261,10 @@ def train_model(
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
iterate_batches=(
iterate_completion_batches if args.mask_inputs else iterate_batches
),
loss=input_masked_loss if args.mask_inputs else default_loss,
)


Expand Down
6 changes: 6 additions & 0 deletions llms/mlx_lm/tokenizer_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from functools import partial
from typing import List

from transformers import AutoTokenizer

Expand Down Expand Up @@ -340,3 +341,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class,
)


def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
removed_bos = sequence if sequence[0] != bos else sequence[1:]
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
Loading