add peft finetune for llm (#38)
* update finetune_sft for llama

* add peft finetune method
tanganke authored Nov 25, 2024
1 parent cc015c6 commit 3d70c73
Showing 6 changed files with 539 additions and 13 deletions.
18 changes: 18 additions & 0 deletions config/fabric/llama_ddp.yaml
@@ -0,0 +1,18 @@
defaults:
- loggers: tensorboard_logger
- _self_

_target_: lightning.Fabric
_recursive_: true
# Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
# The value applies per node.
devices: auto
# Strategy for how to run across multiple devices. Possible choices are:
# ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
strategy: ddp
# The hardware to run on. Possible choices are:
# ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
# for example: fabric.accelerator=cpu
accelerator: auto
# reference to the precision policy: https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
precision: bf16-true
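
For orientation, here is a minimal sketch (not part of this commit) of how a Fabric config such as config/fabric/llama_ddp.yaml is typically turned into a launched lightning.Fabric object via Hydra; the entry-point script and config path are assumptions, not fusion_bench's actual CLI.

import hydra
from lightning import Fabric
from omegaconf import DictConfig

@hydra.main(config_path="config/fabric", config_name="llama_ddp", version_base=None)
def main(cfg: DictConfig) -> None:
    # _recursive_: true lets Hydra build the nested loggers group alongside Fabric itself
    fabric: Fabric = hydra.utils.instantiate(cfg)
    fabric.launch()  # start the processes implied by devices/strategy (DDP here)
    fabric.print(f"world size: {fabric.world_size}, device: {fabric.device}")

if __name__ == "__main__":
    main()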
55 changes: 55 additions & 0 deletions config/method/lm_finetune/peftfinetune_sft.yaml
@@ -0,0 +1,55 @@
_target_: fusion_bench.method.PeftFinetuneSFT
_recursive_: False

optimizer:
_target_: torch.optim.AdamW
fused: True
weight_decay: 0.01
lr: 5e-5

lr_scheduler:
_target_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 5
num_training_steps: _T_max_ # this will be replaced by the expected number of training steps

dataloader_kwargs:
# per-gpu batch size
batch_size: 1
num_workers: 0
pin_memory: True

peft_config:
_target_: peft.LoraConfig
task_type: peft.TaskType.CAUSAL_LM
target_modules:
- query
- value
r: 16
lora_alpha: 16
lora_dropout: 0
bias: none

adapter_name: default
# whether to merge and unload the adapter after training
merge_and_unload: false

# Training hyperparameters
# if max_epochs=-1, max_steps will be used to determine the number of training steps
max_epochs: 3
max_steps: -1
max_steps_per_epoch: -1
accumulate_grad_batches: 1
lr_scheduler_interval: step
lr_scheduler_frequency: 1
# Checkpointing may be done by epoch or step, and at the end of training
# `checkpoint_save_interval` can be 'epoch' or 'step'
checkpoint_save_interval: epoch
checkpoint_save_frequency: 1
# Whether to use gradient clipping, and if so, the value and algorithm
gradient_clip_val: null
gradient_clip_algorithm: norm
save_optimizer_state: false
# save_full_model must be true when using sharded FSDP
save_full_model: false
# Path to checkpoint to load from, used for resuming training
ckpt_path: null
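
A hedged sketch of what the peft_config block above amounts to in plain peft code once instantiated; the base checkpoint is an assumption, and note that LLaMA-style models name their attention projections q_proj/v_proj rather than query/value.

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    # For LLaMA-style models the attention projections are named q_proj/v_proj;
    # adjust to match the module names reported by base_model.named_modules().
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
)
peft_model = get_peft_model(base_model, lora_config, adapter_name="default")
peft_model.print_trainable_parameters()  # only the LoRA adapter weights require grad

# After training, the adapter can optionally be folded back into the base weights,
# mirroring the merge_and_unload option above:
# merged_model = peft_model.merge_and_unload()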
6 changes: 3 additions & 3 deletions fusion_bench/method/__init__.py
@@ -9,7 +9,7 @@
"dummy": ["DummyAlgorithm"],
# single task learning (fine-tuning)
"classification": ["ImageClassificationFineTuningForCLIP"],
"lm_finetune": ["FullFinetuneSFT"],
"lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT"],
# analysis
"analysis": ["TaskVectorCosSimilarity", "TaskVectorViolinPlot"],
# model ensemble methods
@@ -89,8 +89,8 @@


if TYPE_CHECKING:
from .adamerging import *
from .ada_svd import AdaSVDMergingForCLIPVisionModel
from .adamerging import *
from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
from .classification import ImageClassificationFineTuningForCLIP
@@ -116,6 +116,7 @@
SimpleAverageForLlama,
TaskArithmeticForLlama,
)
from .lm_finetune import *
from .mixture_of_experts import (
MixtralForCausalLMMergingAlgorithm,
MixtralForCausalLMUpscalingAlgorithm,
@@ -152,7 +153,6 @@
from .ties_merging import TiesMergingAlgorithm
from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
from .lm_finetune import *

else:
sys.modules[__name__] = LazyImporter(
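
The else branch above hands the module over to a LazyImporter so that heavy submodules are only imported on first attribute access. A simplified stand-in for that pattern (not fusion_bench's actual implementation) looks roughly like this, placed in a package __init__.py:

import importlib
import sys
from types import ModuleType

class LazyImporter(ModuleType):
    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # map each public name to the submodule that defines it
        self._name_to_module = {
            attr: submodule for submodule, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr: str):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        # the submodule is imported only on first access of one of its attributes
        submodule = importlib.import_module(f".{self._name_to_module[attr]}", self.__name__)
        return getattr(submodule, attr)

_import_structure = {"lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT"]}
sys.modules[__name__] = LazyImporter(__name__, _import_structure)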
1 change: 1 addition & 0 deletions fusion_bench/method/lm_finetune/__init__.py
@@ -1 +1,2 @@
from .fullfinetune_sft import FullFinetuneSFT
from .peftfinetune_sft import PeftFinetuneSFT
44 changes: 34 additions & 10 deletions fusion_bench/method/lm_finetune/fullfinetune_sft.py
@@ -4,21 +4,23 @@
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

import lightning as L
import omegaconf
import torch
from lightning.fabric.strategies.fsdp import FSDPStrategy
from lightning.fabric.utilities import rank_zero_only
from omegaconf import DictConfig
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader
from tqdm.auto import tqdm
from typing_extensions import TYPE_CHECKING, override

from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.dataset.llama.collate import padded_collate_sft
from fusion_bench.mixins import LightningFabricMixin
from fusion_bench.modelpool import CausalLMPool
from fusion_bench.utils import instantiate
from fusion_bench.utils.dtype import get_dtype
from lightning.fabric.utilities import rank_zero_only
from omegaconf import DictConfig
from torch import nn
from tqdm.auto import tqdm
from typing_extensions import TYPE_CHECKING, override
from lightning.fabric.strategies.fsdp import FSDPStrategy
import lightning as L

if TYPE_CHECKING:
from lightning.fabric.wrappers import (
@@ -35,7 +37,7 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):

model: Union[nn.Module, "_FabricModule", "LlamaForCausalLM"]
optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
train_dataloader: Union[torch.utils.data.DataLoader, "_FabricDataLoader"]
train_dataloader: Union[DataLoader, "_FabricDataLoader"]
lr_scheduler: torch.optim.lr_scheduler.LRScheduler

def __init__(
@@ -58,6 +60,27 @@ def __init__(
ckpt_path: Optional[str] = None,
**kwargs,
):
"""
Class for full finetuning of a language model on the given SFT datasets.

Args:
optimizer(DictConfig): Configuration for the optimizer.
lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
"""
self._optimizer = optimizer
self._lr_scheduler = lr_scheduler
self.dataloader_kwargs = dataloader_kwargs
@@ -80,6 +103,7 @@ def run(self, modelpool: CausalLMPool):
self.modelpool = modelpool
self.setup()
self.train()
return self.model

def setup_model(self):
model = self.modelpool.load_pretrained_model()
@@ -136,12 +160,12 @@ def setup_data(self):
for dataset_name in modelpool.train_dataset_names
]
if len(train_datasets) > 1:
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
train_dataset = ConcatDataset(train_datasets)
else:
train_dataset = train_datasets[0]

self.train_dataset = train_dataset
self.train_dataloader = torch.utils.data.DataLoader(
self.train_dataloader = DataLoader(
train_dataset,
**self.dataloader_kwargs,
shuffle=True,
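
Based on the constructor arguments documented above, driving the trainer directly might look like the sketch below; the config values mirror the peftfinetune_sft.yaml defaults, and the CausalLMPool construction is left as an assumption rather than a tested recipe.

from omegaconf import OmegaConf
from fusion_bench.method.lm_finetune import FullFinetuneSFT

algorithm = FullFinetuneSFT(
    optimizer=OmegaConf.create(
        {"_target_": "torch.optim.AdamW", "lr": 5e-5, "weight_decay": 0.01, "fused": True}
    ),
    lr_scheduler=OmegaConf.create({
        "_target_": "torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup",
        "num_warmup_steps": 5,
        "num_training_steps": "_T_max_",  # placeholder; replaced with the expected step count at setup time
    }),
    dataloader_kwargs=OmegaConf.create({"batch_size": 1, "num_workers": 0, "pin_memory": True}),
    max_epochs=3,
    max_steps=-1,
    max_steps_per_epoch=-1,
    lr_scheduler_interval="step",
    lr_scheduler_frequency=1,
    checkpoint_save_interval="epoch",
    checkpoint_save_frequency=1,
    accumulate_grad_batches=1,
    gradient_clip_val=None,
    gradient_clip_algorithm="norm",
    save_optimizer_state=False,
    save_full_model=False,
    ckpt_path=None,
)

# modelpool = CausalLMPool(...)             # assumed: built from a fusion_bench modelpool config
# trained_model = algorithm.run(modelpool)  # run() now returns the fine-tuned model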
(Diff for the sixth changed file, fusion_bench/method/lm_finetune/peftfinetune_sft.py, is not shown here.)
