Skip to content

Commit

Permalink
release v0.9.0 (real)
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Sep 8, 2024
1 parent 653fe70 commit 90d6df6
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 53 deletions.
4 changes: 3 additions & 1 deletion .env.local
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ WANDB_DISABLED=
WANDB_PROJECT=huggingface
WANDB_API_KEY=
# gradio ui
GRADIO_SHARE=0
GRADIO_SHARE=False
GRADIO_SERVER_NAME=0.0.0.0
GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
# setup
ENABLE_SHORT_CONSOLE=1
# reserved (do not use)
LLAMABOARD_ENABLED=
LLAMABOARD_WORKDIR=
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
Expand Down
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
Expand Down
3 changes: 3 additions & 0 deletions scripts/cal_mfu.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def calculate_mfu(
"dataset": "c4_demo",
"cutoff_len": seq_length,
"output_dir": os.path.join("saves", "test_mfu"),
"logging_strategy": "no",
"save_strategy": "no",
"save_only_model": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": num_steps,
Expand Down
15 changes: 12 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,34 @@

import os
import re
from typing import List

from setuptools import find_packages, setup


def get_version():
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
return version


def get_requires():
def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines


def get_console_scripts() -> List[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")

return console_scripts


extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
Expand Down Expand Up @@ -72,7 +81,7 @@ def main():
python_requires=">=3.8.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/extras/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ def register_model_group(

register_model_group(
models={
"MiniCPM3-4B": {
"MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
},
Expand Down
70 changes: 24 additions & 46 deletions src/llamafactory/train/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,38 +96,45 @@ def fix_valuehead_checkpoint(


class FixValueHeadModelCallback(TrainerCallback):
r"""
A callback for fixing the checkpoint for valuehead models.
"""

@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
)


class SaveProcessorCallback(TrainerCallback):
r"""
A callback for saving the processor.
"""

def __init__(self, processor: "ProcessorMixin") -> None:
r"""
Initializes a callback for saving the processor.
"""
self.processor = processor

@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(self.processor, "image_processor").save_pretrained(output_dir)

@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)


class PissaConvertCallback(TrainerCallback):
r"""
Initializes a callback for converting the PiSSA adapter to a normal one.
A callback for converting the PiSSA adapter to a normal one.
"""

@override
Expand All @@ -147,9 +154,6 @@ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", contr

@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
Expand Down Expand Up @@ -177,21 +181,22 @@ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control


class LogCallback(TrainerCallback):
r"""
A callback for logging training and evaluation status.
"""

def __init__(self) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
# Progress
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
# Status
self.aborted = False
self.do_train = False
""" Web UI """
# Web UI
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
Expand Down Expand Up @@ -233,9 +238,6 @@ def _close_thread_pool(self) -> None:

@override
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
Expand All @@ -246,60 +248,39 @@ def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control:

@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)

@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()

@override
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True

@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True

@override
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if not self.do_train:
self._close_thread_pool()

@override
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if not self.do_train:
self._close_thread_pool()

@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if not args.should_save:
return

Expand Down Expand Up @@ -342,9 +323,6 @@ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "Tra
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
if self.do_train:
return

Expand Down

0 comments on commit 90d6df6

Please sign in to comment.