Unofficial implementation of the Ask-LLM paper 'How to Train Data-Efficient LLMs', arXiv:2402.09668.
pip install nano-askllm
Note: Flan-T5 models cannot tokenize multilingual text properly (e.g. Japanese).
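You can check this quickly by round-tripping a Japanese string through the tokenizer yourself; a minimal sketch (exact pieces may vary by tokenizer version):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
ids = tokenizer("こんにちは、世界").input_ids  # "Hello, world" in Japanese
print(tokenizer.convert_ids_to_tokens(ids))
# Most characters typically come back as '<unk>' because the SentencePiece
# vocabulary is English-centric, so Japanese datapoints cannot be scored reliably.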
# pip install nano-askllm
# pip install datasets sentencepiece accelerate
from transformers import T5ForConditionalGeneration, T5Tokenizer
from datasets import load_dataset
from nano_askllm import AskLLM
model_id = "google/flan-t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
llm = AskLLM(tokenizer, model)
batch_size = 2
num_ask = 5
for i in range(num_ask):
    datapoints = [item["text"] for item in list(dataset.take(batch_size))]
    scores = llm.ask(datapoints)
    for score, datapoint in zip(scores.tolist(), datapoints):
        text = datapoint[:40].replace("\n", " ")
        print(f"score: {score:.4f}\ttext: {text}")
    dataset = dataset.skip(batch_size)
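The returned scores reflect the model's likelihood of answering "yes", which the paper uses to rank and select pre-training data. As a follow-up you could filter each batch by score; a minimal sketch (the 0.9 cutoff is an arbitrary illustration, not a value from the paper):

threshold = 0.9  # arbitrary cutoff for illustration
selected = [dp for s, dp in zip(scores.tolist(), datapoints) if s >= threshold]
print(f"kept {len(selected)} of {len(datapoints)} datapoints")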
For this model, you need to tweak the prompt template and the yes tokens.
# pip install nano-askllm
# pip install datasets sentencepiece accelerate
# huggingface-cli login
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from nano_askllm import AskLLM
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
dataset = load_dataset("allenai/c4", "ja", split="train", streaming=True)
prompt_template_prefix = "###\n"
prompt_template_postfix = """
###
Does the previous paragraph demarcated within ### and ### contain informative signal for pre-training a large-language model? An informative datapoint should be well-formatted, contain some usable knowledge of the world, and strictly NOT have any harmful, racist, sexist, etc. content.
OPTIONS: yes/no
ANSWER:"""
yes_tokens = ["yes", "Yes", "YES", " yes", " Yes", " YES"]
llm = AskLLM(
    tokenizer,
    model,
    prompt_template_prefix=prompt_template_prefix,
    prompt_template_postfix=prompt_template_postfix,
    yes_tokens=yes_tokens,
    max_tokens=512,  # You can increase it up to 8192 for gemma-2b-it.
)
batch_size = 2
num_ask = 5
for i in range(num_ask):
    datapoints = [item["text"] for item in list(dataset.take(batch_size))]
    scores = llm.ask(datapoints)
    for score, datapoint in zip(scores.tolist(), datapoints):
        text = datapoint[:40].replace("\n", " ")
        print(f"score: {score:.4f}\ttext: {text}")
    dataset = dataset.skip(batch_size)
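When adapting the yes tokens to a new tokenizer, it helps to check how each variant is encoded. A quick check using only the standard transformers API; ideally each variant encodes to a single token id (an assumption based on how first-token logits are typically scored):

for t in yes_tokens:
    print(repr(t), tokenizer(t, add_special_tokens=False).input_ids)
# Variants that split into multiple ids may not be scored as intended.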
For this model, you also need to tweak the prompt template and the yes tokens.
# pip install nano-askllm
# pip install datasets sentencepiece accelerate
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from nano_askllm import AskLLM
model_id = "Rakuten/RakutenAI-7B-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
prompt_template_prefix = "###\n"
prompt_template_postfix = """
###
Does the previous paragraph demarcated within ### and ### contain informative signal for pre-training a large-language model? An informative datapoint should be well-formatted, contain some usable knowledge of the world, and strictly NOT have any harmful, racist, sexist, etc. content.
OPTIONS: yes/no
ANSWER:"""
yes_tokens = ["yes", "Yes"]
llm = AskLLM(
    tokenizer,
    model,
    prompt_template_prefix=prompt_template_prefix,
    prompt_template_postfix=prompt_template_postfix,
    yes_tokens=yes_tokens,
    max_tokens=512,  # You can increase it up to 8192 for Mistral-7B-v0.1 based models.
)
batch_size = 2
num_ask = 5
for i in range(num_ask):
    datapoints = [item["text"] for item in list(dataset.take(batch_size))]
    scores = llm.ask(datapoints)
    for score, datapoint in zip(scores.tolist(), datapoints):
        text = datapoint[:40].replace("\n", " ")
        print(f"score: {score:.4f}\ttext: {text}")
    dataset = dataset.skip(batch_size)
If you want to see debug logs, configure the logger as follows:
from logging import DEBUG, StreamHandler, getLogger
logger = getLogger("nano_askllm.askllm")
logger.setLevel(DEBUG)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.addHandler(handler)
poetry -V # Poetry (version 1.5.1)
git clone https://github.com/susumuota/nano-askllm.git
cd nano-askllm
poetry install
poetry run pytest -s # run pytest once
poetry run -- ptw -- -s # watch for changes and run pytest
@misc{sachdeva2024train,
  title={How to Train Data-Efficient LLMs},
  author={Noveen Sachdeva and Benjamin Coleman and Wang-Cheng Kang and Jianmo Ni and Lichan Hong and Ed H. Chi and James Caverlee and Julian McAuley and Derek Zhiyuan Cheng},
  year={2024},
  eprint={2402.09668},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
MIT License. See LICENSE for details.
- Add Colab notebook
- Add examples using Hugging Face Datasets