Releases: Lightning-AI/pytorch-lightning
Lightning 2.1: Train Bigger, Better, Faster
Lightning AI is excited to announce the release of Lightning 2.1 ⚡ It's the culmination of work from 79 contributors who have worked on features, bug-fixes, and documentation for a total of over 750 commits since v2.0.
The theme of 2.1 is "bigger, better, faster": bigger, because training large multi-billion-parameter models has become even more efficient thanks to improvements in FSDP, efficient initialization, and sharded checkpointing; better, because it's easier than ever to scale models without making substantial code changes or installing third-party packages; and faster, because it leverages the latest hardware features to speed up training in low-bit precision, thanks to new precision plugins like bitsandbytes and Transformer Engine.
And of course, as the name implies, this release fully leverages the latest features in PyTorch 2.1 🎉
Highlights
Improvements To Large-Scale Training With FSDP
The FSDP strategy for training large billion-parameter models gets substantial improvements and new features in Lightning 2.1, both in Trainer and Fabric (in case you didn't know, Fabric is the latest addition to the Lightning family of tools to scale models without the boilerplate code).
FSDP is now more user-friendly to configure, has memory management and speed improvements, and we have a brand new end-to-end user guide with best practices (Trainer, Fabric).
Efficient Saving and Loading of Large Checkpoints
When training large billion-parameter models with FSDP, saving and resuming training, or even just loading model parameters for finetuning, can be challenging, as users are often plagued by out-of-memory errors and speed bottlenecks.
In 2.1, we made several improvements. Starting with saving checkpoints, we added support for distributed/sharded checkpoints, enabled through the `state_dict_type` setting in the strategy (#18364, #18358):
Trainer:
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
# Default used by the strategy
strategy = FSDPStrategy(state_dict_type="full")
# Enable saving distributed checkpoints
strategy = FSDPStrategy(state_dict_type="sharded")
trainer = L.Trainer(strategy=strategy, ...)
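As a quick illustration of what this changes on the saving side, here is a minimal sketch, assuming a placeholder `MyLightningModule` and checkpoint path; with `state_dict_type="sharded"`, the checkpoint is written as a directory of per-rank shard files plus metadata instead of a single file:
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(state_dict_type="sharded")
trainer = L.Trainer(strategy=strategy, accelerator="gpu", devices=8)
trainer.fit(MyLightningModule())

# With "sharded", this writes a checkpoint directory (one shard per rank)
# rather than a single consolidated .ckpt file
trainer.save_checkpoint("checkpoints/my-model.ckpt")

# In a later run, resume by pointing `ckpt_path` at the same directory
trainer.fit(MyLightningModule(), ckpt_path="checkpoints/my-model.ckpt")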
Fabric:
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
# Saving distributed checkpoints is the default
strategy = FSDPStrategy(state_dict_type="sharded")
# Save consolidated (single file) checkpoints
strategy = FSDPStrategy(state_dict_type="full")
fabric = L.Fabric(strategy=strategy, ...)
Distributed checkpoints are the fastest and most memory efficient way to save the state of very large models.
The distributed checkpoint format also makes it efficient to load these checkpoints back for resuming training in parallel, and it reduces the impact on CPU memory usage significantly. Furthermore, we've also introduced lazy-loading for non-distributed checkpoints (#18150, #18379), which greatly reduces the impact on CPU memory usage when loading a consolidated (single-file) checkpoint (e.g. for finetuning). Learn more about these features in our FSDP guides (Trainer, Fabric).
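In Fabric, saving and loading these checkpoints goes through `Fabric.save()` and `Fabric.load()`. Here is a minimal sketch, assuming a placeholder `MyModel` module and checkpoint path:
import torch
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

fabric = L.Fabric(strategy=FSDPStrategy(state_dict_type="sharded"))
fabric.launch()

# `MyModel` is a placeholder for your own nn.Module
model = fabric.setup(MyModel())
optimizer = fabric.setup_optimizers(torch.optim.Adam(model.parameters()))

# With the sharded setting, this writes a folder with one shard file per rank
state = {"model": model, "optimizer": optimizer, "step": 0}
fabric.save("checkpoints/last", state)

# Later, load the sharded checkpoint back into the same state objects in place
fabric.load("checkpoints/last", state)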
Fast and Memory-Optimized Initialization
A major challenge that users face when working with large models such as LLMs is dealing with the extreme memory requirements. Even something as simple as instantiating a model becomes non-trivial if the model is so large that it won't fit on a single GPU or even a single machine. In Lightning 2.1, we are introducing empty-weights initialization through the `Fabric.init_module()` (#17462, #17627) and `Trainer.init_module()`/`LightningModule.configure_model()` (#18004, #18004, #18385) methods:
Trainer:
import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Delay initialization of model to `configure_model()`

    def configure_model(self):
        # Model initialized in correct precision and weights on meta-device
        self.model = ...

    ...

model = MyModel()
trainer = L.Trainer(strategy="fsdp", ...)
trainer.fit(model)
Fabric:
import torch
import lightning as L

fabric = L.Fabric(strategy="fsdp", ...)

# Model initialized in correct precision and weights on meta-device
with fabric.init_module(empty_init=True):
    model = ...

# You can also initialize buffers and tensors directly on device and dtype
with fabric.init_tensor():
    model.mask.create()
    model.kv_cache.create()
    x = torch.randn(4, 128)

# Materialization and sharding of model happens inside here
model = fabric.setup(model)
Read more about this new feature and its other benefits in our docs (Trainer, Fabric).
User-Friendly Configuration
We made it super easy to configure the sharding- and activation-checkpointing policy when you want to auto-wrap particular layers of your model for advanced control (#18045, #18084).
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
- strategy = FSDPStrategy(auto_wrap_policy=ModuleWrapPolicy({MyTransformerBlock}))
+ strategy = FSDPStrategy(auto_wrap_policy={MyTransformerBlock})
trainer = L.Trainer(strategy=strategy, ...)
Furthermore, the sharding strategy can now be conveniently set with a string value (#18087):
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
- from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
- strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
+ strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")
trainer = L.Trainer(strategy=strategy, ...)
You no longer need to remember the long PyTorch imports! Fabric also supports all these improvements shown above.
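For example, a minimal Fabric sketch of the same simplified configuration, reusing the `MyTransformerBlock` placeholder from the example above:
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Both shortcuts work in Fabric's FSDPStrategy as well
strategy = FSDPStrategy(
    auto_wrap_policy={MyTransformerBlock},
    sharding_strategy="SHARD_GRAD_OP",
)
fabric = L.Fabric(strategy=strategy)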
True Half-Precision
Lightning now supports true half-precision for training and inference with all built-in strategies (#18193, #18217, #18213, #18219). With this setting, the memory required to store the model weights is only half of what is normally needed when running with float32. In addition, you get the same speed benefits that mixed precision training (`precision="16-mixed"`) provides:
import lightning as L
# default
trainer = L.Trainer(precision="32-true")
# train with model weights in `torch.float16`
trainer = L.Trainer(precision="16-true")
# train with model weights in `torch.bfloat16`
# (if hardware supports it)
trainer = L.Trainer(precision="bf16-true")
The same settings are also available in Fabric! We recommend trying bfloat16 training (`precision="bf16-true"`), as it is often more numerically stable than regular 16-bit precision (`precision="16-true"`).
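As a quick sketch of the Fabric equivalent (the flag values mirror the Trainer ones above):
import lightning as L

# Train and run inference with model weights in `torch.bfloat16`
# (if the hardware supports it)
fabric = L.Fabric(precision="bf16-true")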
Hotfix for Conda package
Releasing 2.0.9.post0
Weekly patch release
App
Fixed
- Replace LightningClient with import from lightning_cloud (#18544)
Fabric
Fixed
- Fixed an issue causing the `_FabricOptimizer.state` to remain outdated after loading with `load_state_dict` (#18488)
PyTorch
Fixed
- Fixed an issue that prevented the user from setting the `log_model` parameter in `WandbLogger` via the LightningCLI (#18458)
- Fixed the display of `v_num` in the progress bar when running with `Trainer(fast_dev_run=True)` (#18491)
- Fixed `UnboundLocalError` when running with `python -O` (#18496)
- Fixed visual glitch with the TQDM progress bar leaving the validation bar incomplete before switching back to the training display (#18503)
- Fixed false positive warning about logging interval when running with `Trainer(fast_dev_run=True)` (#18550)
Contributors
@awaelchli, @Borda, @justusschock, @SebastianGer
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Weekly patch release
App
Fixed
- Refactor path to root preventing circular import (#18357)
Fabric
Changed
- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process (#16966)
Fixed
- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now (#18238)
- Removed false positive warning when using `fabric.no_backward_sync` with XLA strategies (#17761)
- Fixed issue where Fabric would not initialize the global rank, world size, and rank-zero-only rank after initialization and before launch (#16966)
- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed`, and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 (#18278)
PyTorch
Changed
- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process (#16966)
- Fix inefficiency in rich progress bar (#18369)
Fixed
- Fixed FSDP full-precision `param_dtype` training (`16-mixed` and `bf16-mixed` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 (#18278)
- Fixed an issue that prevented the use of custom logger classes without an `experiment` property defined (#18093)
- Fixed setting the tracking URI in `MLFlowLogger` for logging artifacts to the MLFlow server (#18395)
- Fixed redundant `iter()` call to dataloader when checking dataloading configuration (#18415)
- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now (#18238)
- Properly manage `fetcher.done` with `dataloader_iter` (#18376)
Contributors
@awaelchli, @Borda, @carmocca, @quintenroets, @rlizzo, @speediedan, @tchaton
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Weekly patch release
App
Changed
- Removed the top-level import `lightning.pdb`; import `lightning.app.pdb` instead (#18177)
- Client retries forever (#18065)
Fixed
- Fixed an issue that would prevent the user from setting the multiprocessing start method after importing lightning (#18177)
Fabric
Changed
- Disabled the auto-detection of the Kubeflow environment (#18137)
Fixed
- Fixed issue where DDP subprocesses that used Hydra would set hydra's working directory to current directory (#18145)
- Fixed an issue that would prevent the user from setting the multiprocessing start method after importing lightning (#18177)
- Fixed an issue with `Fabric.all_reduce()` not performing an in-place operation for all backends consistently (#18235)
PyTorch
Added
- Added `LightningOptimizer.refresh()` to update the `__dict__` in case the optimizer it wraps has changed its internal state (#18280)
Changed
- Disabled the auto-detection of the Kubeflow environment (#18137)
Fixed
- Fixed a `Missing folder` exception when using a Google Storage URL as a `default_root_dir` (#18088)
- Fixed an issue that would prevent the user from setting the multiprocessing start method after importing lightning (#18177)
- Fixed the gradient unscaling logic if the training step skipped backward (by returning `None`) (#18267)
- Ensure that the closure running inside the optimizer step has gradients enabled, even if the optimizer step has it disabled (#18268)
- Fixed an issue that could cause the `LightningOptimizer` wrapper returned by `LightningModule.optimizers()` to have different internal state than the optimizer it wraps (#18280)
Contributors
@0x404, @awaelchli, @bilelomrani1, @Borda, @ethanwharris, @nisheethlahoti
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Minor patch release
2.0.6
App
- Fixed handling a `None` request in the file orchestration queue (#18111)
Fabric
- Fixed `TensorBoardLogger.log_graph` not unwrapping the `_FabricModule` (#17844)
PyTorch
- Fixed `LightningCLI` not correctly saving `seed_everything` when `run=True` and `seed_everything=True` (#18056)
- Fixed validation of non-PyTorch LR schedulers in manual optimization mode (#18092)
- Fixed an attribute error for `_FaultTolerantMode` when loading an old checkpoint that pickled the enum (#18094)
Contributors
@awaelchli, @lantiga, @mauvilsa, @shihaoyin
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Minor patch release
App
Added
- plugin: store source app (#17892)
- Added colocation identifier (#16796)
- Added exponential backoff to HTTPQueue put (#18013)
- Content for plugins (#17243)
Changed
- Save a reference to created tasks, to avoid tasks disappearing (#17946)
Fabric
Added
- Added validation against misconfigured device selection when using the DeepSpeed strategy (#17952)
Changed
- Avoid info message when loading 0 entry point callbacks (#17990)
Fixed
- Fixed the emission of a false-positive warning when calling a method on the Fabric-wrapped module that accepts no arguments (#17875)
- Fixed check for FSDP's flat parameters in all parameter groups (#17914)
- Fixed automatic step tracking in Fabric's CSVLogger (#17942)
- Fixed an issue causing the `torch.set_float32_matmul_precision` info message to show multiple times (#17960)
- Fixed loading model state when `Fabric.load()` is called after `Fabric.setup()` (#17997)
PyTorch
Fixed
- Fixed delayed creation of experiment metadata and checkpoint/log dir name when using `WandbLogger` (#17818)
- Fixed incorrect parsing of arguments when augmenting exception messages in DDP (#17948)
- Fixed an issue causing the `torch.set_float32_matmul_precision` info message to show multiple times (#17960)
- Added missing `map_location` argument for the `LightningDataModule.load_from_checkpoint` function (#17950)
- Fix support for `neptune-client` (#17939)
Contributors
@anio, @awaelchli, @Borda, @ethanwharris, @lantiga, @nicolai86, @rjarun8, @schmidt-ai, @schuhschuh, @wouterzwerink, @yurijmikhalevich
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Minor patch release
App
Fixed
- Bumped several dependencies to address security vulnerabilities.
Fabric
Fixed
- Fixed validation of parameters of `plugins.precision.MixedPrecision` (#17687)
- Fixed an issue with HPU imports leading to performance degradation (#17788)
PyTorch
Changed
- Changes to the `NeptuneLogger` (#16761):
  - It now supports neptune-client 0.16.16 and neptune >=1.0, and we have replaced the `log()` method with `append()` and `extend()`.
  - It now accepts a namespace `Handler` as an alternative to `Run` for the `run` argument. This means that you can call it with `NeptuneLogger(run=run["some/namespace"])` to log everything to the `some/namespace/` location of the run.
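For illustration, a minimal sketch of passing a namespace handler; the project name and namespace below are placeholders:
import neptune
from lightning.pytorch.loggers import NeptuneLogger

run = neptune.init_run(project="my-workspace/my-project")

# Pass a namespace handler instead of the run itself; everything is then
# logged under the some/namespace/ location of the run
logger = NeptuneLogger(run=run["some/namespace"])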
Fixed
- Fixed validation of parameters of `plugins.precision.MixedPrecisionPlugin` (#17687)
- Fixed deriving default map location in `LightningModule.load_from_checkpoint` when there is an extra state (#17812)
Contributors
@akreuzer, @awaelchli, @Borda, @jerome-habana, @kshitij12345
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Minor patch release
App
Added
- Added the property `LightningWork.public_ip` that exposes the public IP of the `LightningWork` instance (#17742); see the sketch after this list
- Add missing python-multipart dependency (#17244)
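A minimal sketch of reading these addresses from inside a work, assuming a trivial `LightningWork` subclass:
from lightning.app import LightningWork

class MyWork(LightningWork):
    def run(self):
        # `public_ip` exposes the publicly reachable address of this work;
        # `internal_ip` (see the fix below) returns the private/internal address
        print(self.public_ip, self.internal_ip)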
Changed
- Made type hints public (#17100)
Fixed
- Fixed `LightningWork.internal_ip` that was mistakenly exposing the public IP instead; now exposes the private/internal IP address (#17742)
- Fixed resolution of the latest version in CLI (#17351)
- Fixed property raised instead of returned (#17595)
- Fixed get project (#17617, #17666)
Fabric
Added
- Added support for `Callback` registration through entry points (#17756)
Fixed
- Fixed computing the next version folder in `CSVLogger` (#17139)
- Fixed inconsistent settings for FSDP Precision (#17670)
PyTorch
Changed
- Made type hints public (#17100)
Fixed
- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode (#17639)
- Fixed a potential bug with uploading model checkpoints to Neptune.ai by uploading files from stream (#17430)
- Fixed signature inspection of decorated hooks (#17507)
- The `WandbLogger` no longer flattens dictionaries in the hyperparameters logged to the dashboard (#17574)
- Fixed computing the next version folder in `CSVLogger` (#17139)
- Fixed a formatting issue when the filename in `ModelCheckpoint` contained metrics that were substrings of each other (#17610)
- Fixed `WandbLogger` ignoring the `WANDB_PROJECT` environment variable (#16222)
- Fixed inconsistent settings for FSDP Precision (#17670)
- Fixed an edge case causing overlapping samples in DDP when no global seed is set (#17713)
- Fallback to module available check for mlflow (#17467)
- Fixed LR finder max val batches (#17636)
- Fixed multithreading checkpoint loading (#17678)
Contributors
@adamjstewart, @AleksanderWWW, @awaelchli, @baskrahmer, @bkiat1123, @Borda, @carmocca, @ethanwharris, @leng-yue, @lightningforever, @manangoel99, @mukhery, @Quasar-Kim, @water-vapor, @yurijmikhalevich
If we forgot someone due to not matching commit email with GitHub account, let us know :]