Implement fast checkpoint path #127
Open
fegin wants to merge 65 commits into main from chienchin/fast_checkpoint
Conversation
Adds an option to do torch profiler tracing via: --run_profiler (T/F) and --profile_folder (str). Traces are saved with rank_X as part of the trace name. <img width="1711" alt="rank_named_traces" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/6eb3c3e0-6034-4d1f-8ea8-f43988755714"> Implemented as a context wrapper around the main training loop.
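The wrapper pattern described above can be sketched as follows: profiling is a context manager around the training loop, and a no-op when disabled. This is an illustrative sketch, not the PR's actual code; `maybe_profile` and `_Recorder` are hypothetical names, and in the real code `torch.profiler.profile(...)` would take the place of the `_Recorder` stand-in.

```python
from contextlib import nullcontext


class _Recorder:
    """Stand-in for torch.profiler.profile, so the sketch runs anywhere."""

    def __init__(self, folder: str, rank: int):
        # rank_X becomes part of the trace name, as described above
        self.trace_name = f"{folder}/rank_{rank}_trace.json"
        self.entered = False

    def __enter__(self):
        self.entered = True
        return self

    def __exit__(self, *exc):
        return False  # do not swallow exceptions from the training loop


def maybe_profile(run_profiler: bool, profile_folder: str, rank: int):
    """Return a profiling context when enabled, else a no-op context."""
    return _Recorder(profile_folder, rank) if run_profiler else nullcontext()


# Usage: the main training loop body is unchanged either way.
with maybe_profile(True, "./traces", rank=0) as prof:
    for step in range(3):
        pass  # training step
```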
This commit adds the parallelism scaffolding and applies FSDP plus plain activation checkpointing (AC) to llama.
This PR adds torch.compile to torchtrain. 1 - Control of the compile option is added to the main config toml. 2 - A config dir with config utils has been created to centralize config loading (get_config() returns the project config file). 3 - profile.py and train.py have been updated to load and use get_config(). 4 - If the use_compile option is on (default = true), torch.compile is run on the model and a log line informs the user. Testing: verified torch.compile on/off, with and without profiling. General comments: I did not add the same compile option to the CLI args yet, as it may be better to make a master class that merges args and configs into the final settings for the user. Also named the config folder tt_config (torchtrain config) to avoid confusion with other config dirs, but could also revert to a generic config name. Used ruff for formatting and linting.
I got distracted by merge conflicts and had renamed the config folder to 'train_configs' without updating the default loading path. This quick PR fixes that to ensure nothing breaks on running.
Make dp_degree configurable (=-1 consumes the rest of the mesh not used by PP/SP, =1 disables, >1 specifies the requested data parallel size). Local tests showed these outputs: with --dp_degree=-1 and pp, sp = 1 Building 1-D device mesh with ['dp'], [8] with sp=4 Building 2-D device mesh with ['dp', 'sp'], [2, 4] with sp=2, pp=2 Building 3-D device mesh with ['dp', 'sp', 'pp'], [2, 2, 2] with sp=2, pp=3 (WS=8) AssertionError: Invalid parallel dims: dp(1) * sp(2) * pp(3) != WORLD_SIZE8 with sp=2, pp=3 (WS=6) Building 2-D device mesh with ['sp', 'pp'], [2, 3]
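The dp_degree = -1 rule behind the outputs above can be sketched as a small pure function: dp consumes whatever part of the world size is not used by sp and pp, degenerate dims of size 1 are dropped, and the product must match the world size. The function name is illustrative, not the project's.

```python
def build_mesh_dims(world_size: int, dp: int, sp: int, pp: int):
    """Resolve dp_degree and return the (name, size) dims of the device mesh."""
    if dp == -1:
        dp = world_size // (sp * pp)  # dp consumes the rest of the mesh
    if dp * sp * pp != world_size:
        raise AssertionError(
            f"Invalid parallel dims: dp({dp}) * sp({sp}) * pp({pp}) "
            f"!= WORLD_SIZE({world_size})"
        )
    # drop dims of size 1, mirroring the 1-D/2-D/3-D meshes in the log above
    return [(name, d) for name, d in (("dp", dp), ("sp", sp), ("pp", pp)) if d > 1]
```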
ghstack-source-id: 4cbd0b99d6c4a2454ce835dd8590353a55440882 Pull Request resolved: #25
ghstack-source-id: c5c3fe80db923fcb4630f942432f8bec389924c2 Pull Request resolved: #28
This PR adds a linear lr scheduler and includes some automation based on current best practices: a - takes the user-provided lr from args as lr_max, and computes the final min_lr for the decay schedule as lr / 10, per the Chinchilla paper (i.e. total decay spans one order of magnitude). b - computes an automated linear warmup schedule of 10% of total iters, with a minimum warmup of 2 steps. c - computes a linear decay schedule after warmup, declining from lr_max to lr_min between the end of warmup and the end of training (per Aaron's latest paper, linear is the preferred schedule). d - I updated the learning rate to 8e-4 in order to provide more visible per-iter results to the user, assuming debugmodel. LR scheduling produces a much improved loss curve: <img width="1052" alt="Screenshot 2024-01-28 at 6 39 34 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/667e8520-809f-419e-bfdd-c3bb8f82ff95"> I added two log prints - the warmup schedule as one line, and then the step and current lr at each iter. Both could be disabled if too much info.
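Points a-c can be condensed into one schedule function, sketched below under the stated rules: warmup is 10% of total steps (minimum 2), lr_min is lr_max / 10, and both phases are linear. This is an illustration of the described schedule, not the PR's code; the function name is hypothetical.

```python
def lr_at_step(step: int, total_steps: int, lr_max: float) -> float:
    """Linear warmup to lr_max, then linear decay to lr_max / 10."""
    warmup = max(2, int(0.10 * total_steps))  # 10% of iters, at least 2 steps
    lr_min = lr_max / 10.0                    # one order of magnitude total decay
    if step < warmup:
        return lr_max * (step + 1) / warmup   # linear warmup
    # linear decay from end of warmup to the final step
    frac = (step - warmup) / max(1, total_steps - 1 - warmup)
    return lr_max - frac * (lr_max - lr_min)
```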
Note that you should install the pre-commit tool from `dev-requirements.txt`; you can then invoke `pre-commit` from the CLI to get the same linters/formatters applied locally as are run under CI.
This is so that we can improve our training performance. TODO: this still loads the same data on each rank, which is semantically wrong; we should load a different chunk of data on each data-parallel rank. ghstack-source-id: 23b2b301d3ebb38c016efb21fbbde87bdf1772ef Pull Request resolved: #31
Make it easier to chop Transformer into pieces for PP
Summary: This PR enables checkpointing. It only enables checkpointing to local storage; only once DCP supports automatic storage detection can this checkpoint manager support remote storage. This PR does not checkpoint the dataloader. Test Plan: Changed CHECKPOINT_FOLDER to /tmp/checkpoint_chienchin and ran ./run_llama_train.sh twice. The first run went through all 100 steps and saved the checkpoints. The second run loaded the checkpoint back and detected that the saved step count was 100, so no training was done in the second run.
Somehow torch.compile is not working even though eager sequence parallelism is, so for now it is not turned on by default. ghstack-source-id: 0d251f2efe36e71eae71549d863cb3e128e92634 Pull Request resolved: #32
ghstack-source-id: 08d335e3151097a273742be7cab615a75015d4dd Pull Request resolved: #49
Allows convenient switching of args without editing the .sh file, e.g. `LOG_RANK=1,2 ./run_llama_train.sh` or `NGPU=2 SP=2 ./run_llama_train.sh`
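The pattern that makes this work can be sketched as below: each knob in the script gets a default that an exported environment variable can override. Variable names follow the examples above; the values are illustrative, not the script's actual defaults.

```shell
# Defaults that the caller's environment can override, e.g.:
#   NGPU=2 SP=2 ./run_llama_train.sh
NGPU=${NGPU:-8}
SP=${SP:-1}
LOG_RANK=${LOG_RANK:-0}

echo "launching with NGPU=$NGPU SP=$SP LOG_RANK=$LOG_RANK"
# torchrun --nproc_per_node=$NGPU --local-ranks-filter $LOG_RANK ... train.py
```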
Some modules used `args: ModelArgs`, others `params: ModelArgs`, and others `model_args: ModelArgs`. This PR normalizes everything to use `model_args: ModelArgs` for consistency. (`params` might be confused with `nn.Parameter`s, and `model_args` was more explicit than `args`.) **Test Plan** ``` ./run_llama_train.sh ``` <details> <summary> Output </summary> ``` + TRAINER_DIR=/home/andgu/local/torchtrain + MODEL=debugmodel + NGPU=8 + PP=1 + SP=1 + DP=-1 + LOG_RANK=0 + CHECKPOINT_FOLDER= + CHECKPOINT_INTERVAL=5 + torchrun --nproc_per_node=8 --local-ranks-filter 0 --role rank --tee 3 train.py --steps 10 --compile --pp_degree 1 --sp_degree 1 --dp_degree -1 [2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] [2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] ***************************************** [2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. [2024-02-13 09:53:31,345] torch.distributed.run: [WARNING] ***************************************** [rank0]:2024-02-13 09:53:33,644 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [8] [rank0]:2024-02-13 09:53:36,955 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model [rank0]:2024-02-13 09:53:36,955 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2 [rank0]:/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 [rank0]: warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" [rank0]:2024-02-13 09:53:41,571 - root - INFO - Applied FSDP to the model... [rank0]:2024-02-13 09:53:41,572 - root - INFO - Gradient scaling not enabled. 
[rank0]:2024-02-13 09:53:41,572 - root - INFO - Compiling model llama with torch.compile... [rank0]:2024-02-13 09:53:43,892 - root - INFO - Profiling active. Traces will be saved at ./torchtrain/outputs/profiling/traces [rank0]:NCCL version 2.19.3+cuda12.0 [rank0]:[rank0]:[2024-02-13 09:53:43,995] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1697: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-02-13 09:54:06,085 - root - INFO - step: 1, current loss: 10.54707145690918, lr: [0.0002666666666666667] [rank0]:2024-02-13 09:54:06,153 - root - INFO - step: 2, current loss: 10.481386184692383, lr: [0.0005333333333333334] [rank0]:2024-02-13 09:54:06,222 - root - INFO - step: 3, current loss: 10.334623336791992, lr: [0.0008] [rank0]:2024-02-13 09:54:06,288 - root - INFO - step: 4, current loss: 10.121940612792969, lr: [0.0007] [rank0]:2024-02-13 09:54:06,355 - root - INFO - step: 5, current loss: 9.922933578491211, lr: [0.0006000000000000001] [rank0]:2024-02-13 09:54:06,422 - root - INFO - step: 6, current loss: 9.710294723510742, lr: [0.0005] [rank0]:2024-02-13 09:54:06,487 - root - INFO - step: 7, current loss: 9.587849617004395, lr: [0.0004] [rank0]:2024-02-13 09:54:06,773 - root - INFO - step: 8, current loss: 9.474313735961914, lr: [0.00030000000000000003] [rank0]:STAGE:2024-02-13 09:54:06 3243810:3243810 ActivityProfilerController.cpp:314] Completed Stage: Warm Up [rank0]:2024-02-13 09:54:06,845 - root - INFO - step: 9, current loss: 9.282522201538086, lr: [0.0002] [rank0]:[rank0]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event [rank0]:STAGE:2024-02-13 09:54:06 3243810:3243810 
ActivityProfilerController.cpp:320] Completed Stage: Collection [rank0]:STAGE:2024-02-13 09:54:06 3243810:3243810 ActivityProfilerController.cpp:324] Completed Stage: Post Processing [rank0]:2024-02-13 09:54:06,999 - root - INFO - exporting profile traces to ./torchtrain/outputs/profiling/traces/iteration_10 [rank0]:2024-02-13 09:54:07,002 - root - INFO - step: 10, current loss: 9.34823989868164, lr: [0.0001] ``` </details>
It's a small thing and can be downloaded from OSS; we can just check it in.
This PR adds the following: 1 - via reset_parameters, a full layerwise init for the llama models under /llama. This uses the total model depth as part of the init via: self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 2 - The final output ffn (head) is init with the sqrt of the dim of the model itself and a slightly wider cutoff factor of 3. 3 - tangential change - updates run_llama_train.sh with updated MODEL and MODEL_CONF params to allow direct model control via the sh script. (There was a MODEL already, but it was incorrectly used in place of MODEL_CONF... we should update this as it's not intuitive.) 4 - made the debugmodel default to 2 layers as an improved debug check. 5 - added a 1B and a 40B config for additional testing. I can't currently run 70B on my H100 due to OOM, but can run 40B. Testing: Verified proper init and training with 7B, 13B and ~40B: <img width="1085" alt="Screenshot 2024-02-11 at 10 39 12 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/049037ed-63a4-4ab0-bebc-f297857aab72">
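The depth-dependent init in point 1 is just the formula quoted above, pulled out here as a standalone helper for clarity; `compute_init_std` is an illustrative name, not the code's.

```python
def compute_init_std(num_layers: int, base_std: float = 0.02) -> float:
    """Per-layer weight init std, shrinking with total model depth:
    0.02 / sqrt(2 * num_layers)."""
    return base_std / (2 * num_layers) ** 0.5
```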
This PR is the start of adding perf-related metrics. 1 - This PR adds a function for logging the total number of unique model params, with an option for counting only trainable params (for future PEFT/QLoRA-type work). 2 - logs it with comma-formatted output and the model name: <img width="716" alt="Screenshot 2024-02-12 at 4 12 22 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/8eb48870-ab1e-4b70-9159-92864ff6c0e5"> This also helps demystify the size of our debug model: <img width="716" alt="Screenshot 2024-02-12 at 4 10 17 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/77475306-54bc-48a6-bf28-9c9a542577fd"> **additional updates** - added GPU memory tracking. We want to show the user peak memory stats, as well as monitor and alert on any CUDA caching-allocator retries, which are a perf hindrance. Thus, added class GPUMemoryMonitor. Usage: 1 - instantiate <img width="1329" alt="Screenshot 2024-02-13 at 9 32 11 AM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/95610386-6fde-47bb-bbdc-bb7c399c5895"> 2 - start of training = start_monitoring() 3 - end of training = stop_monitoring() 4 - show results = get_peak_stats_str() and rank0_log it. <img width="1074" alt="Screenshot 2024-02-13 at 9 12 45 AM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/b6c7c854-7d83-436a-bea9-a67109422381">
ghstack-source-id: d0828f16c06747a5af2586630e5205bf786de1c4 Pull Request resolved: #57
ghstack-source-id: da7e02b1c2f21a7471ce1dda8bd4d0ee888ad9ac Pull Request resolved: #60
ghstack-source-id: e23d5e0b70abc427a13bc8bf195c876c007f4939 Pull Request resolved: #65
…ix (#63) This PR 1 - adds multi-node training support via a multinode_trainer.slurm file. Verified llama 7B on 20 nodes / 160 A100s. 2 - It also corrects a race condition that can occur in profiling at larger training scales, where the check for the trace dir's existence fails for process 1, but in the interim process 2 creates the directory, and then when process 1 tries to make the dir it errors out because the dir exists. This is a simple fix of adding exist_ok=True to both makedirs calls (dump folder, trace folder). <img width="1047" alt="Screenshot 2024-02-15 at 10 53 18 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/20378637-4adb-425b-91d8-7fd36289d3b5"> <img width="545" alt="Screenshot 2024-02-15 at 10 55 02 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/28658614-cff6-42b5-ab57-bac578393d5c">
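The race-condition fix in point 2 is worth seeing in isolation: `os.makedirs(..., exist_ok=True)` makes directory creation idempotent, so two ranks racing to create the same trace dir cannot fail on "already exists". The helper name and paths below are illustrative.

```python
import os


def ensure_output_dirs(dump_folder: str, trace_folder: str) -> None:
    """Create output dirs idempotently; safe when many ranks race to create them."""
    os.makedirs(dump_folder, exist_ok=True)
    os.makedirs(trace_folder, exist_ok=True)  # no error if another rank won the race
```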
…g on slurm (#93) This PR adds the ability to do colored console outputs in order to highlight the training data outputs. It also adds a check to skip this color formatting on slurm, which would otherwise print `33=` escape residue in place of the color. Note that I've just added some color to highlight the main training data; users that fork/clone can use it to enhance their outputs as desired. <img width="1372" alt="Screenshot 2024-02-26 at 10 20 15 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/44849821-1677-40bf-896c-39344cd661d6"> Note that on slurm it remains plain: <img width="847" alt="Screenshot 2024-02-26 at 10 46 24 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/172eaa58-4f5c-48f5-8ec1-bc349e3e82f2"> If you don't check this, it would otherwise look like this (this does not happen with this PR, just showing what goes wrong without the check - credit to Yifu for noting this would be an issue): <img width="847" alt="Screenshot 2024-02-26 at 10 39 23 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/4a87fb9a-dd3a-417c-a29e-286ded069358">
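The "no color on slurm" guard can be sketched like this: emit ANSI color codes only when not running under a slurm allocation. The `SLURM_JOB_ID` environment check and the helper name are assumptions for illustration, not necessarily what the PR uses.

```python
import os

GREEN, RESET = "\x1b[32m", "\x1b[39m"


def colorize(text: str, color: str = GREEN) -> str:
    """Wrap text in ANSI color codes, unless running under slurm
    (where the logs would show escape residue instead of color)."""
    if "SLURM_JOB_ID" in os.environ:
        return text
    return f"{color}{text}{RESET}"
```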
This PR updates the GPU metrics labelling to GiB - we were calculating GiB but calling it GB (credit to @awgu for flagging this - issue #94). Function names and member vars in metrics.py have been updated from _gb to _gib for clarity, and the logging output now labels as GiB: <img width="851" alt="Screenshot 2024-02-27 at 11 28 23 AM" src="https://github.com/pytorch/torchtrain/assets/46302957/85eb260a-77e9-4c49-be8a-b1aaa10dc3e2">
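The distinction behind this fix, as a one-line helper: device memory reported in bytes is divided by 2**30 for GiB (binary), not 10**9 for GB (decimal), and the label should match the divisor. The function name is illustrative.

```python
def bytes_to_gib(num_bytes: int) -> float:
    """Convert a byte count to gibibytes (GiB = 2**30 bytes, not 10**9)."""
    return num_bytes / 2**30
```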
ghstack-source-id: 7dc4a80cf9c32f4dca3d00bcef019d256bdf58f7 Pull Request resolved: #96
Enable libUV for torchtrain. Test: ``` + export USE_LIBUV=1 + USE_LIBUV=1 + TRAINER_DIR=/home/gnadathur/local/torchtrain + NGPU=4 + LOG_RANK=0,1 + CONFIG_FILE=./train_configs/debug_model.toml + torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0,1 --role rank --tee 3 train.py --job.config_file ./train_configs/debug_model.toml W0228 09:12:02.564000 140353616004096 torch/distributed/run.py:717] W0228 09:12:02.564000 140353616004096 torch/distributed/run.py:717] ***************************************** W0228 09:12:02.564000 140353616004096 torch/distributed/run.py:717] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0228 09:12:02.564000 140353616004096 torch/distributed/run.py:717] ***************************************** [rank0]:2024-02-28 09:12:04,581 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [4] [rank1]:2024-02-28 09:12:04,708 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [4] [rank0]:2024-02-28 09:12:05,647 - root - INFO - Building llama [rank0]:2024-02-28 09:12:05,655 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model [rank0]:2024-02-28 09:12:05,655 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2 [rank1]:2024-02-28 09:12:07,299 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model [rank1]:2024-02-28 09:12:07,299 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2 [rank0]:2024-02-28 09:12:07,565 - root - INFO - Model fully initialized via reset_params [rank0]:2024-02-28 09:12:07,566 - root - INFO - Model built with: ModelArgs(dim=256, n_layers=2, n_heads=16, n_kv_heads=None, vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=32768, depth_init=True) 
[rank0]:2024-02-28 09:12:07,566 - root - INFO - Model llama debugmodel size: 18,089,216 total parameters [rank0]:2024-02-28 09:12:07,567 - root - INFO - GPU memory usage: NVIDIA H100 (0): 95.0396 GiB capacity, 0.0 GiB in-use, 0.0% in-use [rank0]:2024-02-28 09:12:08,769 - root - INFO - Applied FSDP to the model... [rank0]:2024-02-28 09:12:08,770 - root - INFO - Gradient scaling not enabled. [rank0]:2024-02-28 09:12:08,770 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./outputs/tb/20240228-0912. [rank0]:2024-02-28 09:12:08,977 - root - INFO - Profiling active. Traces will be saved at ./outputs/profiling/traces [rank0]:2024-02-28 09:12:10,956 - root - INFO - step: 1 loss: 10.9229 iter: 1.9386 data: 0.0368 lr: 0.00026667 [rank0]:2024-02-28 09:12:11,045 - root - INFO - step: 2 loss: 10.8673 iter: 0.0562 data: 0.0316 lr: 0.00053333 [rank0]:2024-02-28 09:12:11,130 - root - INFO - step: 3 loss: 10.7145 iter: 0.0523 data: 0.0322 lr: 0.0008 [rank0]:2024-02-28 09:12:11,219 - root - INFO - step: 4 loss: 10.5038 iter: 0.0559 data: 0.0319 lr: 0.0007 [rank0]:2024-02-28 09:12:11,304 - root - INFO - step: 5 loss: 10.2228 iter: 0.0537 data: 0.031 lr: 0.0006 [rank0]:2024-02-28 09:12:11,391 - root - INFO - step: 6 loss: 9.9677 iter: 0.0562 data: 0.0302 lr: 0.0005 [rank0]:2024-02-28 09:12:11,478 - root - INFO - step: 7 loss: 9.7762 iter: 0.0544 data: 0.0317 lr: 0.0004 [rank0]:2024-02-28 09:12:11,676 - root - INFO - step: 8 loss: 9.4359 iter: 0.0509 data: 0.0322 lr: 0.0003 [rank1]:STAGE:2024-02-28 09:12:11 3161834:3161834 ActivityProfilerController.cpp:314] Completed Stage: Warm Up [rank1]:[rank1]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event [rank0]:STAGE:2024-02-28 09:12:11 3161833:3161833 ActivityProfilerController.cpp:314] Completed Stage: Warm Up [rank0]:2024-02-28 09:12:11,813 - root - INFO - step: 9 loss: 9.2326 iter: 0.1007 data: 0.0321 lr: 0.0002 [rank0]:[rank0]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event [rank1]:STAGE:2024-02-28 09:12:11 3161834:3161834 ActivityProfilerController.cpp:320] Completed Stage: Collection [rank1]:STAGE:2024-02-28 09:12:11 3161834:3161834 ActivityProfilerController.cpp:324] Completed Stage: Post Processing [rank0]:STAGE:2024-02-28 09:12:11 3161833:3161833 ActivityProfilerController.cpp:320] Completed Stage: Collection [rank0]:STAGE:2024-02-28 09:12:11 3161833:3161833 ActivityProfilerController.cpp:324] Completed Stage: Post Processing [rank0]:2024-02-28 09:12:12,195 - root - INFO - exporting profile traces to ./outputs/profiling/traces/iteration_10 [rank0]:2024-02-28 09:12:12,207 - root - INFO - step: 10 loss: 9.1641 iter: 0.0971 data: 0.031 lr: 0.0001 [rank0]:2024-02-28 09:12:12,207 - root - INFO - Average iter time: 0.0670 seconds [rank0]:2024-02-28 09:12:12,207 - root - INFO - Average data load time: 0.0314 seconds [rank0]:2024-02-28 09:12:12,208 - root - INFO - Current Memory: NVIDIA H100 (0): Reserved: 9.6465%, Alloc 2.1969%, Active: 2.2% [rank0]:Peak Memory: Reserved 9.65%, Alloc 8.43%, Active: 8.44% [rank0]:num retries: 0, num ooms: 0 [rank0]:NCCL version 2.19.3+cuda12.0 ``` --------- Co-authored-by: gnadathur <[email protected]>
As titled - we don't want to allow the steps == -1 case, as it would blow up the lr scheduler.
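The validation described above can be sketched as a small up-front check: reject non-positive step counts before the lr scheduler ever sees them, since the schedule's denominators assume a finite, positive total. The function name is illustrative.

```python
def validate_steps(steps: int) -> int:
    """Reject steps == -1 (and other non-positive values) before building
    the lr scheduler, whose warmup/decay math assumes a finite total."""
    if steps <= 0:
        raise ValueError(f"training steps must be a positive integer, got {steps}")
    return steps
```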
Add a 7B config and adjust options to be more realistic. Didn't add this to the train scripts as the default since it's expensive to init; whoever uses it can adjust accordingly.
ghstack-source-id: f7ee3c867bfcdcae5dbb490982920606191e6f40 Pull Request resolved: #97
Summary: Adding a description field, useful for integration tests to describe the test. Test Plan: ``` + export USE_LIBUV=1 + USE_LIBUV=1 + TRAINER_DIR=/home/gnadathur/local/torchtrain + NGPU=4 + LOG_RANK=0,1 + CONFIG_FILE=./train_configs/debug_model.toml + torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0,1 --role rank --tee 3 train.py --job.config_file ./train_configs/debug_model.toml W0229 17:05:02.466000 140187679912960 torch/distributed/run.py:717] W0229 17:05:02.466000 140187679912960 torch/distributed/run.py:717] ***************************************** W0229 17:05:02.466000 140187679912960 torch/distributed/run.py:717] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0229 17:05:02.466000 140187679912960 torch/distributed/run.py:717] ***************************************** [rank1]:2024-02-29 17:05:04,269 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [4] [rank0]:2024-02-29 17:05:04,510 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [4] [rank0]:2024-02-29 17:05:05,327 - root - INFO - Starting job: debug training [rank0]:2024-02-29 17:05:05,327 - root - INFO - Building llama [rank0]:2024-02-29 17:05:05,335 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model [rank0]:2024-02-29 17:05:05,335 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2 [rank1]:2024-02-29 17:05:06,782 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model [rank1]:2024-02-29 17:05:06,782 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2 [rank0]:2024-02-29 17:05:07,347 - root - INFO - Model fully initialized via reset_params [rank0]:2024-02-29 17:05:07,349 - root - INFO - Model built with: ModelArgs(dim=256, n_layers=2, n_heads=16, n_kv_heads=None, 
vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=32768, depth_init=True) [rank0]:2024-02-29 17:05:07,349 - root - INFO - Model llama debugmodel size: 18,089,216 total parameters [rank0]:2024-02-29 17:05:07,349 - root - INFO - GPU memory usage: NVIDIA H100 (0): 95.0396 GiB capacity, 0.0 GiB in-use, 0.0% in-use [rank0]:2024-02-29 17:05:08,375 - root - INFO - Applied FSDP to the model... [rank0]:2024-02-29 17:05:08,376 - root - INFO - Gradient scaling not enabled. [rank0]:2024-02-29 17:05:08,376 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./outputs/tb/20240229-1705. [rank0]:2024-02-29 17:05:08,610 - root - INFO - Profiling active. Traces will be saved at ./outputs/profiling/traces [rank0]:2024-02-29 17:05:10,570 - root - INFO - step: 1 loss: 10.9183 iter: 1.9258 data: 0.0303 lr: 0.00026667 [rank0]:2024-02-29 17:05:10,653 - root - INFO - step: 2 loss: 10.8347 iter: 0.0487 data: 0.0336 lr: 0.00053333 [rank0]:2024-02-29 17:05:10,733 - root - INFO - step: 3 loss: 10.6861 iter: 0.045 data: 0.0334 lr: 0.0008 [rank0]:2024-02-29 17:05:10,812 - root - INFO - step: 4 loss: 10.4672 iter: 0.0453 data: 0.0336 lr: 0.0007 [rank0]:2024-02-29 17:05:10,893 - root - INFO - step: 5 loss: 10.2154 iter: 0.0466 data: 0.033 lr: 0.0006 [rank0]:2024-02-29 17:05:10,975 - root - INFO - step: 6 loss: 9.9573 iter: 0.0496 data: 0.0314 lr: 0.0005 [rank0]:2024-02-29 17:05:11,056 - root - INFO - step: 7 loss: 9.7627 iter: 0.0486 data: 0.0321 lr: 0.0004 [rank0]:2024-02-29 17:05:11,201 - root - INFO - step: 8 loss: 9.437 iter: 0.0457 data: 0.0333 lr: 0.0003 [rank1]:STAGE:2024-02-29 17:05:11 3368103:3368103 ActivityProfilerController.cpp:314] Completed Stage: Warm Up [rank1]:[rank1]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event [rank0]:STAGE:2024-02-29 17:05:11 3368102:3368102 ActivityProfilerController.cpp:314] Completed Stage: Warm Up [rank0]:2024-02-29 17:05:11,317 - root - INFO - step: 9 loss: 9.2446 iter: 0.0794 data: 0.0324 lr: 0.0002 [rank0]:[rank0]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event [rank1]:STAGE:2024-02-29 17:05:11 3368103:3368103 ActivityProfilerController.cpp:320] Completed Stage: Collection [rank1]:STAGE:2024-02-29 17:05:11 3368103:3368103 ActivityProfilerController.cpp:324] Completed Stage: Post Processing [rank0]:STAGE:2024-02-29 17:05:11 3368102:3368102 ActivityProfilerController.cpp:320] Completed Stage: Collection [rank0]:STAGE:2024-02-29 17:05:11 3368102:3368102 ActivityProfilerController.cpp:324] Completed Stage: Post Processing [rank0]:2024-02-29 17:05:11,748 - root - INFO - exporting profile traces to ./outputs/profiling/traces/iteration_10 [rank0]:2024-02-29 17:05:11,762 - root - INFO - step: 10 loss: 9.1772 iter: 0.0893 data: 0.0324 lr: 0.0001 [rank0]:2024-02-29 17:05:11,763 - root - INFO - Average iter time: 0.0578 seconds [rank0]:2024-02-29 17:05:11,763 - root - INFO - Average data load time: 0.0326 seconds [rank0]:2024-02-29 17:05:11,763 - root - INFO - Current Memory: NVIDIA H100 (0): Reserved: 9.6465%, Alloc 2.1969%, Active: 2.2% [rank0]:Peak Memory: Reserved 9.65%, Alloc 8.43%, Active: 8.44% [rank0]:num retries: 0, num ooms: 0 [rank0]:NCCL version 2.19.3+cuda12.0 ``` Reviewers: Subscribers: Tasks: Tags: Co-authored-by: gnadathur <[email protected]>
ghstack-source-id: 1c5bf790d7473f6a24124051fcfa1fd2585a56f9 Pull Request resolved: #105
``` + export USE_LIBUV=1 + USE_LIBUV=1 + TRAINER_DIR=/home/gnadathur/local/torchtrain + NGPU=4 + LOG_RANK=0,1 + CONFIG_FILE=./train_configs/debug_model.toml + torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0,1 --role rank --tee 3 train.py --job.config_file ./train_configs/debug_model.toml W0304 17:01:26.766000 140549371597824 torch/distributed/run.py:717] W0304 17:01:26.766000 140549371597824 torch/distributed/run.py:717] ***************************************** W0304 17:01:26.766000 140549371597824 torch/distributed/run.py:717] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0304 17:01:26.766000 140549371597824 torch/distributed/run.py:717] ***************************************** [rank0]:2024-03-04 17:01:28,834 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [4] [rank1]:2024-03-04 17:01:28,857 - torchtrain.parallelisms - INFO - Building 1-D device mesh with ('dp',), [4] [rank0]:2024-03-04 17:01:29,712 - root - INFO - Starting job: debug training [rank0]:2024-03-04 17:01:29,712 - root - INFO - Building llama [rank0]:2024-03-04 17:01:29,719 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model [rank0]:2024-03-04 17:01:29,719 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2 [rank1]:2024-03-04 17:01:31,187 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model [rank1]:2024-03-04 17:01:31,188 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2 [rank0]:2024-03-04 17:01:31,346 - root - INFO - Model fully initialized via reset_params [rank0]:2024-03-04 17:01:31,346 - root - INFO - Model built with: ModelArgs(dim=256, n_layers=2, n_heads=16, n_kv_heads=None, vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, 
max_seq_len=32768, depth_init=True) [rank0]:2024-03-04 17:01:31,347 - root - INFO - Model llama debugmodel size: 18,089,216 total parameters [rank0]:2024-03-04 17:01:31,347 - root - INFO - GPU memory usage: NVIDIA H100 (0): 95.0396 GiB capacity, 0.0 GiB in-use, 0.0% in-use [rank0]:2024-03-04 17:01:32,502 - root - INFO - Applied FSDP to the model... [rank0]:2024-03-04 17:01:32,503 - root - INFO - Gradient scaling not enabled. [rank0]:2024-03-04 17:01:32,504 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./outputs/tb/20240304-1701. [rank0]:2024-03-04 17:01:32,901 - root - INFO - Profiling active. Traces will be saved at ./outputs/profiling/traces [rank0]:2024-03-04 17:01:34,806 - root - INFO - step: 1 loss: 10.8424 iter: 1.8688 data: 0.0316 lr: 0.00026667 [rank0]:2024-03-04 17:01:34,891 - root - INFO - step: 2 loss: 10.7581 iter: 0.0476 data: 0.0357 lr: 0.00053333 [rank0]:2024-03-04 17:01:34,970 - root - INFO - step: 3 loss: 10.6239 iter: 0.045 data: 0.0333 lr: 0.0008 [rank0]:2024-03-04 17:01:35,048 - root - INFO - step: 4 loss: 10.4163 iter: 0.0455 data: 0.0323 lr: 0.0007 [rank0]:2024-03-04 17:01:35,127 - root - INFO - step: 5 loss: 10.1529 iter: 0.0459 data: 0.032 lr: 0.0006 [rank0]:2024-03-04 17:01:35,206 - root - INFO - step: 6 loss: 9.8899 iter: 0.0468 data: 0.0311 lr: 0.0005 [rank0]:2024-03-04 17:01:35,284 - root - INFO - step: 7 loss: 9.7204 iter: 0.0461 data: 0.0312 lr: 0.0004 [rank0]:2024-03-04 17:01:35,425 - root - INFO - step: 8 loss: 9.3757 iter: 0.0457 data: 0.0319 lr: 0.0003 [rank0]:STAGE:2024-03-04 17:01:35 3850444:3850444 ActivityProfilerController.cpp:314] Completed Stage: Warm Up [rank0]:2024-03-04 17:01:35,537 - root - INFO - step: 9 loss: 9.1883 iter: 0.0762 data: 0.0318 lr: 0.0002 [rank0]:[rank0]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event [rank1]:STAGE:2024-03-04 17:01:35 3850445:3850445 ActivityProfilerController.cpp:314] Completed Stage: Warm Up [rank1]:[rank1]:[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event [rank0]:STAGE:2024-03-04 17:01:35 3850444:3850444 ActivityProfilerController.cpp:320] Completed Stage: Collection [rank0]:STAGE:2024-03-04 17:01:35 3850444:3850444 ActivityProfilerController.cpp:324] Completed Stage: Post Processing [rank1]:STAGE:2024-03-04 17:01:35 3850445:3850445 ActivityProfilerController.cpp:320] Completed Stage: Collection [rank1]:STAGE:2024-03-04 17:01:35 3850445:3850445 ActivityProfilerController.cpp:324] Completed Stage: Post Processing [rank0]:2024-03-04 17:01:35,958 - root - INFO - exporting profile traces to ./outputs/profiling/traces/iteration_10 [rank0]:2024-03-04 17:01:35,971 - root - INFO - step: 10 loss: 9.1212 iter: 0.0808 data: 0.0319 lr: 0.0001 [rank0]:2024-03-04 17:01:35,972 - root - INFO - Average iter time: 0.0553 seconds [rank0]:2024-03-04 17:01:35,972 - root - INFO - Average data load time: 0.0317 seconds [rank0]:2024-03-04 17:01:35,972 - root - INFO - Current Memory: NVIDIA H100 (0): Reserved: 9.6465%, Alloc 2.1969%, Active: 2.2% [rank0]:Peak Memory: Reserved 9.65%, Alloc 8.43%, Active: 8.44% [rank0]:num retries: 0, num ooms: 0 [rank0]:NCCL version 2.19.3+cuda12.0 ``` Co-authored-by: gnadathur <[email protected]>
This PR enables meta_init functionality to avoid OOMing on CPU for larger models. The core functionality is in meta_init.py, with a few changes in parallelization and train.py.

Key items:

1 - This is largely the same as the earlier PR I had for meta_init, but I did a new one b/c that was faster than reworking it with all the interim changes.

2 - To address feedback in the previous PR:

a - Why do we need meta_init.py? Can't we just do:
~~~
with torch.device("meta"):
    model = Model.from_args(...)
~~~
Unfortunately this does not work b/c the rope embeddings are treated differently (as a buffer), and thus the simple lambda call from param_init_fn in FSDP (lambda module: module.to_device('cuda')) will not invoke or move the rope embeddings, and the model will fail on the first forward. This issue relates to the nn.embeddings not being moved, and to the device being referenced in the forward pass of the current rope class. Have opened #110 to track and investigate this while not holding up a working meta_init from landing.

b - Per earlier feedback, meta_init is now not optional but simply the default. This should ensure all models leverage it and that we aren't missing things for future meta_init aspects.

3 - Misc change: I switched model_params to the normal all-params count instead of 'unique params', b/c the latter does not mesh with what people perceive model size as.

Testing: tested both debugmodel and the 26B model with and without meta_init to confirm the same loss curves.

Note for future reference: if you get a bad init (meta_init failure) you will simply not train (loss is the same every iter). If you fail to call reset_params after FSDP, you will train (b/c we default to torch.randn_like) but your starting loss will be 5x+ higher, telling you that you have not properly init'ed the model.
Co-authored-by: gnadathur <[email protected]>
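The meta-device pattern the commit above describes can be sketched as follows. This is a minimal illustration with a toy nn.Linear, not the PR's actual meta_init.py; it shows why an explicit materialize-then-reinit step is needed:

```python
import torch
import torch.nn as nn

# Construct on the meta device: parameters get shapes and dtypes but no
# storage, so even very large models cost nothing to build on CPU.
with torch.device("meta"):
    model = nn.Linear(4, 4)
assert model.weight.is_meta  # no real memory allocated yet

# Materialize empty storage on the target device, then re-run the real
# initialization. State created outside __init__ (like the rope buffers
# mentioned above) is exactly what a generic param_init_fn can miss,
# hence the explicit meta_init path.
model = model.to_empty(device="cpu")
model.reset_parameters()
assert not model.weight.is_meta
print(tuple(model.weight.shape))  # (4, 4)
```

Skipping reset_parameters here reproduces the failure mode described in the note above: the model trains on whatever garbage to_empty left behind.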
ghstack-source-id: 5133a8d97ad209b569e0fc528e58daafdd31d80d Pull Request resolved: #114
ghstack-source-id: a0c8b4454f75ad1cd9824ac89a1df0182f6a7d8c Pull Request resolved: #112
…data' at 40 iters issue) (#88)

This PR adds the minipile (1M, 6GB) dataset as an option for pretraining with torchtrain. It resolves the issue where we run out of data after 40 iterations with the default alpaca dataset.

Per @tianyu-l's excellent suggestion, have refactored to a single hf_datasets.py file that supports both minipile and alpaca, since it turned out there was no need for different tokenizers, etc. Also cleaned up the datasets package so that create_tokenizer is exposed directly, and thus all public APIs can be used directly from torchtrain.datasets.

Lastly, added a warning for when a dataset is being re-looped so users don't get burned by overfitting:
[screenshot: re-loop warning]

Adds a color highlight to showcase which dataloader was built:
[screenshots: dataloader highlight]

Usage: just add "minipile" or "alpaca" as the dataset in the training config toml file.
[screenshot: config excerpt]

Testing: verified training loss is improving, and ran for 100 iters to verify there is no longer an out-of-data issue with minipile. Reran with alpaca and saw the expected out-of-data at 40 iters without the infinite-loop option; it runs to 100 with infinite.

Notes: I did not make this the default dataset, since for debugmodel mostly running 10 iters is fine and there's 6GB to pull down.
[screenshot: download size]
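The re-loop-with-warning behavior described above can be sketched framework-free like this; loop_forever and its argument are illustrative names, not the torchtrain API:

```python
import logging

logger = logging.getLogger("hf_datasets")

def loop_forever(build_iter):
    """Yield samples indefinitely, warning each time the underlying
    finite dataset is exhausted so users notice they may be overfitting."""
    passes = 0
    while True:
        for sample in build_iter():
            yield sample
        passes += 1
        logger.warning("dataset exhausted; re-looping (pass %d)", passes)

samples = [1, 2, 3]
it = loop_forever(lambda: iter(samples))
first_five = [next(it) for _ in range(5)]
print(first_five)  # [1, 2, 3, 1, 2]
```

The build_iter callable (rather than a bare iterator) matters: a fresh iterator is needed for each pass, which mirrors re-opening the HF dataset stream.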
ghstack-source-id: 3c930054d3b04faf3866048740a2ef887d066dd6 Pull Request resolved: #117
ghstack-source-id: 733bf85716cda3a5b9af780eba79c9b5dd66abad Pull Request resolved: #121
ghstack-source-id: d7cd26d84aa2442ac45223992e1766397e52c8d8 Pull Request resolved: #122
according to suggestions in #118 (comment) ghstack-source-id: 357f0872cd1c9bad2c4c256d47adbd3f716a7651 Pull Request resolved: #123
…t job configs (#124)

This PR:

1 - Adds the English-language portion of the c4 dataset, which has 177M entries (https://huggingface.co/datasets/allenai/c4). Due to the size, streaming is enabled as the default. This is allenai/c4; apparently the original c4 is being deprecated and HF advises using allenai's now.

For comparison, per @tianyu-l's request, 40-iteration averages:
alpaca cached loading: Average data load time: 0.0279 seconds
c4 streaming loading: Average data load time: 0.0290 seconds
There is a longer initial delay during 'preparing c4' vs alpaca (roughly 45 seconds vs 10 seconds), but after that the speed is similar.

Dataset sample (not displayed in training, just an excerpt pulled to double-check the data flow):
[screenshot: c4 sample]

2 - Also updated the multi-node slurm file to account for the new job config.

Test: verified no looping with 100 iterations; sampled streamed data to verify.
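The streaming trade-off above comes down to lazy iteration. A toy stand-in (no HF datasets dependency; the generator and its field names are illustrative):

```python
from itertools import islice

def stream_c4_like(n_docs):
    # Stand-in for a streaming HF dataset: documents are produced lazily,
    # so nothing like the full 177M-entry corpus is materialized up front.
    # The price is the longer first-touch latency noted above.
    for i in range(n_docs):
        yield {"text": f"document {i}"}

# Training pulls only what it needs; memory stays flat regardless of n_docs.
first = list(islice(stream_c4_like(177_000_000), 2))
print([d["text"] for d in first])  # ['document 0', 'document 1']
```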
This PR uses shared memory to do async checkpoint on another process and also implements async staging (overlapping staging with the next iteration).
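In outline, the staging half of this looks like the following. This is a deliberately simplified single-process sketch using a thread; the PR itself stages into shared memory consumed by a separate checkpoint process, and the names below are illustrative rather than the PR's API:

```python
import copy
import threading

class AsyncStager:
    """Snapshot the state dict on a background thread so the copy
    overlaps with the next training iteration."""

    def __init__(self):
        self._thread = None
        self.staged = None

    def stage(self, state_dict):
        # Kick off the snapshot; training continues meanwhile.
        def _copy():
            self.staged = copy.deepcopy(state_dict)
        self._thread = threading.Thread(target=_copy)
        self._thread.start()

    def wait_staging(self):
        # Cheap no-op when nothing is in flight; otherwise block until
        # the snapshot is safe to hand to the checkpoint writer.
        if self._thread is not None:
            self._thread.join()
            self._thread = None

stager = AsyncStager()
state = {"step": 9, "w": [0.1, 0.2]}
stager.stage(state)
# ... the next iteration's forward/backward would run here ...
stager.wait_staging()         # sync point before mutating state again
state["step"] = 10
print(stager.staged["step"])  # 9: the snapshot is unaffected
```

The key invariant is that wait_staging must run before the optimizer mutates the live state dict again, which is exactly where the diff below places it in the training loop.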
facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Mar 12, 2024
fegin commented on Mar 12, 2024:
@@ -241,6 +247,10 @@ def main(job_config: JobConfig):
        # backward on scaled loss to create scaled gradients
        scaler.scale(loss).backward()

        # This reduces to a cheap if-check when no checkpoint staging is
        # in flight, so it does not affect performance.
        checkpoint.wait_staging()
@autodelete here is the point to synchronize with the background staging.
autodelete reviewed on Mar 13, 2024:
rank0_log(
    f"Sending the state dict to the background process, {time.monotonic()}."
)
self.mp_queue_send.put((state_dict, checkpoint_id))
s/state_dict/self.cpu_offload_state_dict/ here?
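The queue handoff this snippet and review comment discuss is the usual background-writer pattern. A stdlib-only sketch, with a thread standing in for the PR's separate checkpoint process and all names illustrative (the real worker would save the reusable cpu_offload_state_dict copy via a distributed-checkpoint write rather than just acknowledging):

```python
import queue
import threading

def checkpoint_worker(recv_q, done_q):
    """Consume (state_dict, checkpoint_id) pairs and persist them off the
    training critical path."""
    while True:
        item = recv_q.get()
        if item is None:            # sentinel: shut down cleanly
            break
        state_dict, checkpoint_id = item
        # ... write state_dict to storage here ...
        done_q.put(checkpoint_id)   # ack in place of an actual disk write

send_q, done_q = queue.Queue(), queue.Queue()
worker = threading.Thread(target=checkpoint_worker, args=(send_q, done_q))
worker.start()

send_q.put(({"step": 10}, "step-10"))  # cf. mp_queue_send.put(...) above
saved_id = done_q.get()                # training keeps running meanwhile
send_q.put(None)
worker.join()
print(saved_id)  # step-10
```

The sentinel shutdown matters for the real multi-process version too: without it, the background checkpoint process outlives training and the job hangs on exit.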