Merge pull request #45 from tanganke/develop
merge develop into main
Showing 21 changed files with 568 additions and 81 deletions.
@@ -0,0 +1,6 @@
alpaca-cleaned:
  _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
  tokenizer: ???
  path: "yahma/alpaca-cleaned"
  split: train
  cache_path: null
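Configs like this are resolved with Hydra's `instantiate`; the `???` is OmegaConf's MISSING marker, a mandatory field that must be supplied at instantiation time. A minimal sketch of how this config might be used, assuming the loader takes the tokenizer as a keyword argument like the other loaders in this PR (the config path here is hypothetical):

from hydra.utils import instantiate
from omegaconf import OmegaConf
from transformers import AutoTokenizer

# Hypothetical path to the YAML file shown above.
cfg = OmegaConf.load("config/dataset/llama/alpaca-cleaned.yaml")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# `tokenizer: ???` must be filled in; instantiate() accepts it as a kwarg override.
dataset = instantiate(cfg["alpaca-cleaned"], tokenizer=tokenizer)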
@@ -0,0 +1,3 @@
ultrachat-200k:
  _target_: fusion_bench.dataset.ultrachat.load_tokenized_ultrachat_200k
  tokenizer: ???
@@ -0,0 +1,16 @@
defaults:
  - loggers: tensorboard_logger
  - strategy: llama_peft_fsdp
  - _self_

_target_: lightning.Fabric
_recursive_: true
# Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
# The value applies per node.
devices: auto
# The hardware to run on. Possible choices are:
# ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
# For example: fabric.accelerator=cpu
accelerator: auto
# Reference for the precision policy: https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
precision: bf16-true
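Since `_target_` points at `lightning.Fabric`, instantiating this config is equivalent to constructing Fabric directly (with the `strategy` argument coming from the composed `llama_peft_fsdp` config below). A minimal sketch of the equivalent direct construction:

from lightning import Fabric

# Direct equivalent of the config above, minus the composed FSDP strategy.
fabric = Fabric(devices="auto", accelerator="auto", precision="bf16-true")
fabric.launch()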
@@ -0,0 +1,9 @@
_target_: lightning.fabric.strategies.FSDPStrategy
sharding_strategy: FULL_SHARD
state_dict_type: full  # save a single, consolidated checkpoint file
cpu_offload: false
auto_wrap_policy:
  _target_: fusion_bench.mixins.lightning_fabric.get_size_based_auto_wrap_policy
activation_checkpointing_policy: ${.auto_wrap_policy}
# limit_all_gathers: true
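The `get_size_based_auto_wrap_policy` helper is not shown in this diff; presumably it returns a policy in the style of PyTorch's built-in `size_based_auto_wrap_policy`, which puts each submodule whose parameter count exceeds a threshold into its own FSDP unit. A hypothetical sketch of such a helper:

import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def get_size_based_auto_wrap_policy(min_num_params: int = 100_000_000):
    # Hypothetical: wrap every submodule with more than `min_num_params`
    # parameters into its own FSDP unit. The same callable can double as the
    # activation-checkpointing policy, which is what the config above does
    # via the `${.auto_wrap_policy}` interpolation.
    return functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)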
@@ -0,0 +1,18 @@
_target_: fusion_bench.modelpool.CausalLMPool

pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct

models:
  _pretrained_:
    _target_: transformers.AutoModelForCausalLM.from_pretrained
    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
    torch_dtype: bfloat16

tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}

train_datasets:
  ultrachat-200k:
    _target_: fusion_bench.dataset.llama.ultrachat.load_tokenized_ultrachat_200k
    tokenizer: ${...tokenizer}
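The `${...pretrained_model_name_or_path}` references are OmegaConf relative interpolations: the first dot selects the current node and each additional dot climbs one level, so `${...}` from `models._pretrained_` and `${..}` from `tokenizer` both resolve to the top-level key. A quick check:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "pretrained_model_name_or_path": "meta-llama/Llama-3.2-1B-Instruct",
        "models": {
            "_pretrained_": {
                # two levels below the root, hence three dots
                "pretrained_model_name_or_path": "${...pretrained_model_name_or_path}"
            }
        },
    }
)
print(cfg.models._pretrained_.pretrained_model_name_or_path)
# -> meta-llama/Llama-3.2-1B-Instruct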
config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml (14 additions, 0 deletions)
@@ -0,0 +1,14 @@
_target_: fusion_bench.modelpool.SeqenceClassificationModelPool

pretrained_model_name_or_path: fusion-bench/Llama-3.2-1B-Instruct_Bradly-Terry-RM_Preference-700k

models:
  _pretrained_:
    _target_: transformers.AutoModelForSequenceClassification.from_pretrained
    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
    torch_dtype: bfloat16

tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
  pad_token: <|end_of_text|>  # do not use the EOS token (<|eot_id|>) as the padding token, because it marks the end of each message
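The `pad_token` entry is forwarded to `AutoTokenizer.from_pretrained`. For Llama-3-style instruct tokenizers this keeps the padding token distinct from the turn-terminating token, which matters because the preference loaders in this PR append the EOS id to mark the end of each response. A minimal sketch (the assertion should hold for Llama-3 Instruct tokenizers, where EOS is `<|eot_id|>`):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "fusion-bench/Llama-3.2-1B-Instruct_Bradly-Terry-RM_Preference-700k",
    pad_token="<|end_of_text|>",
)
# Padding must not alias the token that terminates each turn.
assert tokenizer.pad_token_id != tokenizer.eos_token_id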
@@ -0,0 +1,18 @@
_target_: fusion_bench.taskpool.llama.reward_model.RewardModelEvaluationTaskPool

test_datasets:
  preference_700k:
    _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_rlhf
    tokenizer: ${...tokenizer}
    path: hendrydong/preference_700K
    split: train
    cache_path: null

dataloader_kwargs:
  shuffle: False
  batch_size: 16

tokenizer: ${..modelpool.tokenizer}

max_num_samples: 1000
seed: 42
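The task pool's implementation is not part of this diff, but the evaluation it configures is the standard Bradley-Terry pairwise check: the reward model should score the chosen response above the rejected one. A minimal sketch of that metric (hypothetical, not the actual `RewardModelEvaluationTaskPool` code):

import torch


@torch.no_grad()
def pairwise_accuracy(reward_model, batch):
    # `batch` holds padded chosen_*/rejected_* tensors, as produced by the
    # preference dataset loaders in this PR plus a padding collator.
    chosen = reward_model(
        input_ids=batch["chosen_input_ids"],
        attention_mask=batch["chosen_attention_mask"],
    ).logits.squeeze(-1)
    rejected = reward_model(
        input_ids=batch["rejected_input_ids"],
        attention_mask=batch["rejected_attention_mask"],
    ).logits.squeeze(-1)
    # Fraction of pairs where the chosen response out-scores the rejected one.
    return (chosen > rejected).float().mean().item()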
@@ -0,0 +1,88 @@
import os
from copy import deepcopy
from typing import TYPE_CHECKING, Optional

from datasets import Dataset, load_dataset, load_from_disk
from lightning.fabric.utilities import rank_zero_only
from tqdm.auto import tqdm

from fusion_bench.utils import timeit_context

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


def load_tokenized_stanford_shp_for_rlhf(
    tokenizer: "PreTrainedTokenizer",
    path: str = "stanfordnlp/SHP",
    split: str = "train",
    num_proc: int = 8,
    cache_path: Optional[str] = None,
):
    if cache_path is not None and os.path.isdir(cache_path):
        dataset = load_from_disk(cache_path)
        return dataset

    dataset = load_dataset(path, split=split)

    def tokenize(sample):
        """
        - history: the post title concatenated to the post body (string)
        - human_ref_A: text of comment A (string)
        - human_ref_B: text of comment B (string)
        - labels: the preference label -- it is 1 if A is preferred to B; 0 if B is preferred to A. This was randomized such that the label distribution is roughly 50/50. (integer)
        """
        # Create a conversation with the post title and body, then append the
        # preferred (chosen) and dispreferred (rejected) comments.
        conversation = [{"role": "user", "content": sample["history"]}]
        if sample["labels"] == 0:
            chosen_ref, rejected_ref = sample["human_ref_B"], sample["human_ref_A"]
        else:
            chosen_ref, rejected_ref = sample["human_ref_A"], sample["human_ref_B"]
        # Note: list.append returns None, so build new lists instead of
        # chaining deepcopy(...).append(...).
        sample["chosen"] = deepcopy(conversation) + [
            {"role": "assistant", "content": chosen_ref}
        ]
        sample["rejected"] = deepcopy(conversation) + [
            {"role": "assistant", "content": rejected_ref}
        ]

        # apply chat template
        sample["chosen_chat"] = tokenizer.apply_chat_template(
            sample["chosen"], tokenize=False, add_generation_prompt=False
        )
        sample["rejected_chat"] = tokenizer.apply_chat_template(
            sample["rejected"], tokenize=False, add_generation_prompt=False
        )

        # tokenize the conversations
        tokenized_pos = tokenizer(sample["chosen_chat"], truncation=True)
        tokenized_neg = tokenizer(sample["rejected_chat"], truncation=True)

        # Ensure that the chosen sequence contains no interior EOS token and
        # ends with exactly one EOS token.
        sample["chosen_input_ids"] = tokenized_pos["input_ids"]
        sample["chosen_attention_mask"] = tokenized_pos["attention_mask"]
        assert (
            tokenizer.eos_token_id not in tokenized_pos["input_ids"][:-1]
        ), f"Prompt contains EOS token: {sample['chosen']}"
        if sample["chosen_input_ids"][-1] != tokenizer.eos_token_id:
            sample["chosen_input_ids"].append(tokenizer.eos_token_id)
            sample["chosen_attention_mask"].append(1)

        # Same for the rejected sequence.
        sample["rejected_input_ids"] = tokenized_neg["input_ids"]
        sample["rejected_attention_mask"] = tokenized_neg["attention_mask"]
        assert (
            tokenizer.eos_token_id not in tokenized_neg["input_ids"][:-1]
        ), f"Prompt contains EOS token: {sample['rejected']}"
        if sample["rejected_input_ids"][-1] != tokenizer.eos_token_id:
            sample["rejected_input_ids"].append(tokenizer.eos_token_id)
            sample["rejected_attention_mask"].append(1)

        return sample

    dataset = dataset.map(tokenize, num_proc=num_proc)

    if cache_path is not None and rank_zero_only.rank == 0:
        dataset.save_to_disk(cache_path)
    return dataset
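A smoke test mirroring the one at the bottom of the ultrachat loader below could exercise this function (hypothetical usage, not part of the diff):

if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    dataset = load_tokenized_stanford_shp_for_rlhf(tokenizer)
    print(dataset)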
@@ -0,0 +1,58 @@
import os
from typing import TYPE_CHECKING, Optional

from datasets import Dataset, load_dataset, load_from_disk
from lightning.fabric.utilities import rank_zero_only
from tqdm.auto import tqdm

from fusion_bench.utils import timeit_context

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


def load_tokenized_ultrachat_200k(
    tokenizer: "PreTrainedTokenizer",
    path: str = "HuggingFaceH4/ultrachat_200k",
    split: str = "train_sft",
    num_proc: int = 8,
    cache_path: Optional[str] = None,
):
    R"""
    Load and tokenize the UltraChat 200k dataset.

    The returned dataset contains the following fields:

    - input_ids: The token ids of the templated conversation.
    - attention_mask: The attention mask for the conversation.
    """
    if cache_path is not None and os.path.exists(cache_path):
        dataset = load_from_disk(cache_path)
        return dataset

    dataset = load_dataset(path, split=split)

    def tokenize(sample):
        # ? is it necessary to `.replace(tokenizer.bos_token, "")`?
        sample["input_ids"] = tokenizer.apply_chat_template(
            sample["messages"], tokenize=True, add_generation_prompt=False
        )
        sample["attention_mask"] = [1] * len(sample["input_ids"])
        return sample

    dataset = dataset.map(tokenize, num_proc=num_proc)

    if cache_path is not None and rank_zero_only.rank == 0:
        dataset.save_to_disk(cache_path)
    return dataset


if __name__ == "__main__":
    # Example usage and testing
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    dataset = load_tokenized_ultrachat_200k(tokenizer)
    print(dataset)
Empty file.