Training LLMs with Quantized LoRA + FSDP.
Read our announcement blog post.
You should treat this script as an alpha/preview release. If you’re not comfortable with testing and debugging models, we’d suggest holding off for a few months while the community more fully tests the approach.
The following steps should work (tested on CUDA 11.7, 11.8 and 12.1):
- Install most of the dependencies in one step (replace 118 with your desired CUDA version):
  pip install llama-recipes fastcore --extra-index-url https://download.pytorch.org/whl/test/cu118
- Install bitsandbytes:
  pip install "bitsandbytes>=0.43.0"
- Run the following to access Llama 2:
  huggingface-cli login
- Optional libraries:
  - HQQ quantization: follow the HQQ installation instructions. Our training script uses HQQBackend.ATEN_BACKPROP, so also make sure to build the custom kernels:
    cd hqq/kernels && python setup_cuda.py install
  - Weights and Biases logging:
    pip install wandb
- PyTorch >= 2.2 is recommended to make use of the native flash-attention 2 kernel.
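To confirm the PyTorch >= 2.2 recommendation is met, a small version check like the one below can help. It is a sketch: the helper name torch_version_at_least is hypothetical (not part of the repository), and it only compares the major and minor numbers, so build tags such as "+cu118" and pre-release suffixes are ignored.

```python
import re


def torch_version_at_least(version: str, minimum: tuple = (2, 2)) -> bool:
    """Return True if a version string such as '2.2.1+cu118' is >= minimum.

    Only the leading major.minor numbers are compared, so local build tags
    ('+cu118') and pre-release suffixes ('.dev20240101') are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)


# Usage (requires PyTorch to be installed):
#   import torch
#   print(torch.__version__, torch_version_at_least(torch.__version__))
```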
Once installed, run cd fsdp_qlora and then run the following command to begin finetuning Llama-2 70B on Alpaca with a maximum sequence length of 2048 tokens:
python train.py \
--model_name meta-llama/Llama-2-70b-hf \
--batch_size 2 \
--context_length 2048 \
--precision bf16 \
--train_type qlora \
--use_gradient_checkpointing true \
--use_cpu_offload true \
--dataset alpaca \
--reentrant_checkpointing true
For quantization we support HQQ and bitsandbytes. We're currently benchmarking both to help you decide which to use. If you do use bitsandbytes, be sure to pass --reentrant_checkpointing True to avoid triggering a bug in bitsandbytes that results in high memory usage (a fix is in progress).
Full params fine-tuning.
export CUDA_VISIBLE_DEVICES=4,5 # optionally set devices
# --world_size is optional (on a single machine it is set automatically)
# --master_port is optional (defaults to 12355)
python train.py \
--world_size 2 \
--master_port 12356 \
--model_name meta-llama/Llama-2-7b-hf \
--gradient_accumulation_steps 4 \
--batch_size 8 \
--context_length 512 \
--precision bf16 \
--train_type full \
--use_gradient_checkpointing true \
--use_cpu_offload false \
--use_activation_cpu_offload false \
--log_to wandb \
--dataset alpaca
LoRA fine-tuning using the HF PEFT library (this and the following variants are shown as changes to the full fine-tuning command above):
- --train_type full \
+ --train_type lora \
LoRA fine-tuning using a custom LoRA module.
- --train_type full \
+ --train_type custom_lora \
4-bit quantized LoRA fine-tuning using the bitsandbytes Linear4bit layer with NF4 quantization and the HF PEFT library.
- --train_type full \
+ --train_type qlora \
+ --reentrant_checkpointing true \
4-bit quantized LoRA fine-tuning using the bitsandbytes Linear4bit layer with NF4 quantization and a custom LoRA module.
- --train_type full \
+ --train_type custom_qlora \
+ --reentrant_checkpointing true \
4-bit quantized LoRA fine-tuning using HQQ library and a custom LoRA module.
- --train_type full \
+ --train_type hqq_lora \
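The variants above differ only in --train_type and, for the two bitsandbytes-based QLoRA types, the extra --reentrant_checkpointing flag. The sketch below captures that mapping in a hypothetical helper (build_command is not part of the repository); the remaining flag values are copied from the full fine-tuning example.

```python
from typing import Dict, List

# Flag values copied from the full fine-tuning example above.
BASE_ARGS: Dict[str, str] = {
    "--model_name": "meta-llama/Llama-2-7b-hf",
    "--gradient_accumulation_steps": "4",
    "--batch_size": "8",
    "--context_length": "512",
    "--precision": "bf16",
    "--use_gradient_checkpointing": "true",
    "--use_cpu_offload": "false",
    "--use_activation_cpu_offload": "false",
    "--log_to": "wandb",
    "--dataset": "alpaca",
}

TRAIN_TYPES = {"full", "lora", "custom_lora", "qlora", "custom_qlora", "hqq_lora"}

# The bitsandbytes-based QLoRA variants also pass --reentrant_checkpointing true.
NEEDS_REENTRANT = {"qlora", "custom_qlora"}


def build_command(train_type: str) -> List[str]:
    """Return the train.py argv for one of the training variants listed above."""
    if train_type not in TRAIN_TYPES:
        raise ValueError(f"unknown train_type: {train_type!r}")
    args = dict(BASE_ARGS)
    args["--train_type"] = train_type
    if train_type in NEEDS_REENTRANT:
        args["--reentrant_checkpointing"] = "true"
    argv = ["python", "train.py"]
    for flag, value in args.items():
        argv += [flag, value]
    return argv


if __name__ == "__main__":
    import shlex
    print(shlex.join(build_command("custom_qlora")))
```

Keeping the per-variant differences in one small table makes it harder to forget the re-entrant checkpointing flag when switching between bitsandbytes and HQQ.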
During quantized LoRA training we use custom quantization and loading code to avoid loading the entire model into GPU memory before sharding it across GPUs. This is the default behavior of our training script when any of the training options "qlora", "custom_qlora" or "hqq_lora" is used. The other training options are already optimized for low-memory loading as far as possible.
We load the weights iteratively, quantize them on the GPU, and place them back on the CPU or the meta device (depending on rank), processing a few layers at a time concurrently. We do this on all GPUs so that every rank initializes the quantization parameters, such as zero and scale, and we use sync_module_states=True to sync the model parameters and buffers across all GPUs during FSDP initialization.
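A pure-Python sketch of the rank-dependent loading described above, under loudly simplified assumptions: weights are lists of floats, quantize_4bit is a simple 16-level min/max quantizer rather than NF4 or HQQ, and a meta tensor is modelled as metadata without data. It processes one layer at a time sequentially, whereas the script handles a few layers concurrently. All names here are hypothetical.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class QuantizedWeight:
    device: str                   # "cpu" on rank 0, "meta" (shape only) elsewhere
    numel: int                    # kept on every rank, like a meta tensor's shape
    codes: Optional[List[int]]    # 4-bit codes, materialized only on rank 0
    scale: float                  # quantization parameters are computed on every rank
    zero: float


def quantize_4bit(values: List[float]) -> Tuple[List[int], float, float]:
    """Asymmetric min/max quantization to 16 levels (a stand-in for NF4/HQQ)."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 15 or 1.0     # guard against constant weights
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo


def dequantize_4bit(w: QuantizedWeight) -> List[float]:
    if w.codes is None:
        raise ValueError("meta weights carry no data until synced from rank 0")
    return [c * w.scale + w.zero for c in w.codes]


def load_quantized(layers: List[List[float]], rank: int) -> List[QuantizedWeight]:
    """Quantize one layer at a time, keeping data only where this rank needs it."""
    out = []
    for weight in layers:   # peak memory stays around one layer, not the whole model
        codes, scale, zero = quantize_4bit(weight)   # done on the GPU in the real script
        if rank == 0:
            out.append(QuantizedWeight("cpu", len(weight), codes, scale, zero))
        else:
            out.append(QuantizedWeight("meta", len(weight), None, scale, zero))
    return out
```

In the flow described above, rank 0's materialized values are then broadcast to the other ranks during FSDP initialization via sync_module_states=True, which fills in the data the meta placeholders lack.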
The --precision option controls how parameters are stored and which precision is used for computation.
Pure bfloat16 (--precision bf16): This will cast all the model parameters to torch.bfloat16 before training and won't use FSDP mixed precision. As a result, sharded and unsharded params will be stored in bf16, forward and backward passes will be done in bf16, and gradient reduction and updates will be done in bf16.
Pure float32: This will cast all the model parameters to torch.float32 before training and won't use FSDP mixed precision. As a result, sharded and unsharded params will be stored in fp32, forward and backward passes will be done in fp32, and gradient reduction and updates will be done in fp32.
Mixed float16 with autocast: This will cast all the model parameters to torch.float32 before training and will use FSDP mixed precision with
mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
As a result, sharded and unsharded params will be stored in fp32. It will use autocast(torch.float16) for the forward and backward passes, while gradient reduction and updates are done in fp32 (reduce_dtype=torch.float32).
Mixed bfloat16 with autocast: This will cast all the model parameters to torch.float32 before training and will use FSDP mixed precision with
mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
As a result, sharded and unsharded params will be stored in fp32. It will use autocast(torch.bfloat16) for the forward and backward passes, while gradient reduction and updates are done in fp32 (reduce_dtype=torch.float32).
bfloat16 parameters with float32 buffers and autocast (--precision bf16_buffers_autocast): This will cast all the model parameters to torch.bfloat16 before training but will keep the buffers in torch.float32, and will use FSDP mixed precision with
mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
As a result, sharded and unsharded params will be stored in bf16. It will use autocast(torch.bfloat16) for the forward and backward passes, and gradient reduction and updates will be done in bf16. Buffers stay in fp32, and only operations eligible for autocast are performed in bf16.
This option is important for the RoPE layer, which gives incorrect results when cast to lower precision, especially with longer context lengths.
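The pure-Python sketch below emulates bfloat16 rounding to show why longer contexts suffer more: bfloat16 keeps only 8 significant bits, so beyond position 256 not every integer position survives the cast, and the absolute error grows with the position. It assumes, for illustration, that the position indices (or the angles position × frequency) end up in bfloat16; the function names are hypothetical.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a Python float to the nearest bfloat16 value (round half to even).

    bfloat16 is the upper 16 bits of an IEEE float32; this emulates the cast
    with integer bit manipulation (NaN/inf handling omitted).
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round half to even on the dropped bits
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def max_position_error(context_length: int) -> float:
    """Largest |p - bf16(p)| over integer positions 0..context_length-1."""
    return max(abs(p - to_bfloat16(float(p))) for p in range(context_length))


for n in (256, 512, 2048, 4096):
    print(n, max_position_error(n))   # 0.0, 1.0, 4.0, 8.0 respectively
```

For the highest-frequency RoPE dimension the rotation angle equals the position itself, so an error of 4 positions at a 2048-token context becomes an angle error of 4 radians.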
hf_train.py uses TRL's SFTTrainer for a comparison run. To match our script, modify the data-loading code to train on everything (not just completions) and then run:
python train.py --train_type qlora --dataset guanaco --batch_size 8 --lr_scheduler cosine --log_to wandb --save_model True --output_dir guanaco_7B --gradient_accumulation_steps 2 --lr 2e-4
The SFTTrainer version has to run with a lower batch size (4 vs. 8), so it uses 4 gradient accumulation steps while the QLoRA+FSDP version uses only 2, which keeps the effective batch size the same.
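The batch-size matching above is simple arithmetic; this sketch spells it out, assuming the batch size is per device and that both runs use the same number of GPUs (so the world size cancels). effective_batch_size is a hypothetical helper.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, world_size: int = 1) -> int:
    """Sequences contributing to one optimizer step (assumes batch size is per device)."""
    return per_device_batch * grad_accum_steps * world_size


# Both runs are assumed to use the same number of GPUs, so world_size cancels out.
sft_trainer = effective_batch_size(per_device_batch=4, grad_accum_steps=4)
qlora_fsdp = effective_batch_size(per_device_batch=8, grad_accum_steps=2)
print(sft_trainer, qlora_fsdp)  # 16 16
```

Under the same assumption, the full fine-tuning example earlier (--batch_size 8, --gradient_accumulation_steps 4, --world_size 2) works out to 8 × 4 × 2 = 64 sequences per optimizer step.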
If you specify --save_model True, the adapter layers will be saved as a state dict. To convert to the regular Hugging Face format and upload to the Hub, see Converting the State Dict.ipynb.
If the "custom_qlora" or "hqq_lora" training options are used, only the trainable LoRA parameters will be saved. Before inference, you need to load and quantize the base model again, and separately load the saved LoRA parameters.
Alternatively, you can test whether merging the base model weights with the trained LoRA weights and then quantizing them performs similarly to keeping the parameters separate, as done during training. To make use of torch.compile with HQQ, see mobiusml/hqq#18.
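A toy comparison of the two inference options above: keeping the LoRA weights separate, y = Q(W)x + (alpha/r)·B(Ax), versus merging first and then quantizing, y = Q(W + (alpha/r)·BA)x. It assumes the standard LoRA scaling alpha/r, and fake_quantize is a simple symmetric absmax round-to-nearest quantizer standing in for NF4 or HQQ; all names and numbers are illustrative, so real comparisons should use evaluation metrics on the actual model.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a: Matrix, b: Matrix, s: float = 1.0) -> Matrix:
    """Elementwise a + s * b."""
    return [[x + s * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def fake_quantize(w: Matrix, bits: int = 4) -> Matrix:
    """Symmetric absmax round-to-nearest quantize + dequantize (stand-in for NF4/HQQ)."""
    qmax = 2 ** (bits - 1) - 1                       # 7 levels per sign for 4 bits
    scale = (max(abs(v) for row in w for v in row) or 1.0) / qmax
    return [[round(v / scale) * scale for v in row] for row in w]


def lora_outputs(W: Matrix, A: Matrix, B: Matrix, x: Matrix,
                 alpha: float, r: int, bits: int = 4) -> Tuple[Matrix, Matrix]:
    s = alpha / r                                      # standard LoRA scaling
    # Kept separate (as during training): quantized base + full-precision adapter.
    separate = add(matmul(fake_quantize(W, bits), x), matmul(B, matmul(A, x)), s)
    # Merged first, then quantized as a single weight.
    merged = matmul(fake_quantize(add(W, matmul(B, A), s), bits), x)
    return separate, merged


W = [[0.9, -0.3], [0.2, 0.6]]   # frozen base weight (2 x 2)
A = [[0.05, 0.02]]              # LoRA down-projection (r=1 x 2)
B = [[0.4], [-0.1]]             # LoRA up-projection (2 x r=1)
x = [[1.0], [2.0]]              # input as a column vector

separate, merged = lora_outputs(W, A, B, x, alpha=1.0, r=1)
print(separate, merged)
```

In this toy case the two options give noticeably different outputs at 4 bits, while at high bit widths they agree closely, which is why the merge-then-quantize route is worth checking empirically rather than assuming.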
While QLoRA finetuning works with FSDP, there are some rough edges to be aware of with this alpha release and our example script.
First, the current release of Transformers' AutoModel.from_pretrained cannot be used to load models into quantized weights, as it does not support the new quant_storage or quantization flag. Loading pretrained models requires writing or using custom model-loading code. We provide an example of how to load and quantize a QLoRA model for finetuning in our demo script.
We are actively working with Hugging Face to resolve this incompatibility in future Transformers and PEFT releases.
Second, while FSDP’s Mixed Precision works with QLoRA, practitioners need to be careful to set MixedPrecision.param_dtype to match the Linear4bit.quant_storage dtype. Otherwise, FSDP’s Mixed Precision could cast the quantized weights to a different precision, essentially turning them into random weights. Our example script shows how to avoid this potential pitfall, and we will be happy to assist model training libraries in correctly exposing FSDP’s Mixed Precision options to users when training with QLoRA.
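To see why a mismatch between MixedPrecision.param_dtype and Linear4bit.quant_storage is destructive, the pure-Python sketch below packs 4-bit codes into a buffer viewed as bfloat16 (as if quant_storage were torch.bfloat16) and applies a value-preserving cast to float16 (as if param_dtype were torch.float16). The nibble layout and helper names are assumptions for illustration, not bitsandbytes internals.

```python
import struct
from typing import List


def pack_nibbles(codes: List[int]) -> bytes:
    """Pack pairs of 4-bit codes into bytes (high nibble first; an assumed layout)."""
    return bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))


def unpack_nibbles(data: bytes) -> List[int]:
    return [n for byte in data for n in (byte >> 4, byte & 0xF)]


def bf16_bits_to_float(bits: int) -> float:
    # A bfloat16 value is the upper half of a float32.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def cast_bf16_storage_to_fp16(data: bytes) -> bytes:
    """Value-preserving bf16 -> fp16 cast of a raw buffer, like a mixed-precision param cast.

    struct raises OverflowError for values outside the fp16 range.
    """
    out = bytearray()
    for i in range(0, len(data), 2):
        bits = int.from_bytes(data[i:i + 2], "little")
        out += struct.pack("<e", bf16_bits_to_float(bits))
    return bytes(out)


codes = [3, 7, 12, 1]
storage = pack_nibbles(codes)                 # two bytes = one bf16 "element"
casted = cast_bf16_storage_to_fp16(storage)   # the cast preserves the value, not the bytes
print(unpack_nibbles(storage), unpack_nibbles(casted))
```

The cast keeps the numeric value of the fake bf16 element but rewrites its bits, so the recovered codes no longer match the originals. With matching dtypes no cast happens, and the packed bytes, and therefore the codes, are left intact.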
# BnB QLoRA
export CUDA_VISIBLE_DEVICES=4,5,6,7
python train.py \
--world_size 4 \
--master_port 12356 \
--model_name meta-llama/Llama-2-70b-hf \
--gradient_accumulation_steps 4 \
--batch_size 2 \
--context_length 512 \
--precision bf16_buffers_autocast \
--train_type custom_qlora \
--use_gradient_checkpointing true \
--reentrant_checkpointing true \
--use_cpu_offload false \
--log_to stdout \
--dataset alpaca
# HQQ QLoRA
export CUDA_VISIBLE_DEVICES=4,5,6,7
python train.py \
--world_size 4 \
--master_port 12356 \
--model_name meta-llama/Llama-2-70b-hf \
--gradient_accumulation_steps 4 \
--batch_size 2 \
--context_length 512 \
--precision bf16_buffers_autocast \
--train_type hqq_lora \
--use_gradient_checkpointing true \
--use_cpu_offload false \
--log_to stdout \
--dataset alpaca
Note: For large-batch or long-context training, HQQ LoRA is somewhat more memory-efficient than BnB LoRA with re-entrant checkpointing, so if you run into OOM issues, try using HQQ LoRA.
See fsdp_multi_node.sh for an example script for multi-node training with SLURM.