Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fused CEL #457

Draft
wants to merge 48 commits into
base: main
Choose a base branch
from
Draft

[WIP] Fused CEL #457

wants to merge 48 commits into from

Conversation

jeromeku
Copy link

@jeromeku jeromeku commented May 13, 2024

Efficient Fused Cross Entropy Loss

Memory-efficient cross entropy implementation that only materializes the derivatives of the language modeling head layer without storing the logits and chunks the computation of the logits such that the full logits tensor is never realized.

This is a direct adaptation of this repo.

Contents

Overview

In short:

  • the logits, derivative with respect to the hidden state inputs to the language modeling head layer (dX hereafter), and the derivative with respect to the logits projection weights (dW hereafter) are computed in chunks
  • the logits are overwritten by its derivatives within a custom loss kernel to avoid additional memory allocations.

See the original repo for an excellent explanation of the design.

Changes

The following changes were made to the original kernel:

  • Reshape inputs and labels to adapt the 3-D language modeling tensors with the required shapes of the kernel.
  • Upcast loss to float32, which in the original kernel was initialized to the autocasted / in-feat dtype.
  • Add torch.cuda.amp.{custom_fwd,custom_bwd} to the autograd.Function.

All changes are enumerated in unsloth/kernels/fused_cel.py.

Additionally, adapter layers and configs in fused_cel.py enable integration with transformers and unsloth.

Tests

See tests/test_CEL.py for correctness checks.

The comments in the tests describe numerical edge cases.

Benchmarks

Following are results from preliminary benchmarking / testing on a L4 NVIDIA GPU for a small llama-like model with and without the fused CEL layer.

The takeaway is that the memory efficiency claims of the original repo are evident, with overall memory usage lower, decreasing linearly with the number of loop iterations.

Can be reproduced by passing the provided options to benchmark_hf_test_cel.py (run with --help to see all options).

Below is the overall config, followed by training losses / grad norms and overall training metrics for float32 and bfloat16.

Test config:

  • max_steps=50
  • model_id=hf-internal-testing/tiny-random-LlamaForCausalLM
  • batch_size=2
  • max_seq_len=256
  • packing=True
  • grad_accum_steps=1
  • load_in_4bit=False
  • use_lora=False
  • fused_cel_n_loop_iters=[1, 2, 4]

float32

  • n_loop_it=1
loss grad_norm
fused_cel no-fused absdiff fused_cel no-fused absdiff
1 10.369300 10.369300 0.000000 0.375981 0.375981 0.000000
2 10.383600 10.383600 0.000000 0.409343 0.409344 0.000000
3 10.374800 10.374800 0.000000 0.411205 0.411205 0.000000
4 10.380000 10.380000 0.000000 0.337345 0.337345 0.000000
5 10.376800 10.376800 0.000000 0.354001 0.354001 0.000000
6 10.363800 10.363800 0.000000 0.457850 0.457851 0.000000
7 10.379100 10.379100 0.000000 0.327099 0.327099 0.000000
8 10.372200 10.372200 0.000000 0.324939 0.324939 0.000000
9 10.360500 10.360500 0.000000 0.463365 0.463365 0.000000
10 10.369700 10.369700 0.000000 0.345713 0.345714 0.000000
11 10.377000 10.377000 0.000000 0.323786 0.323786 0.000000
12 10.363000 10.363000 0.000000 0.366833 0.366833 0.000000
13 10.358700 10.358700 0.000000 0.386118 0.386118 0.000000
14 10.362500 10.362500 0.000000 0.345925 0.345925 0.000000
15 10.368100 10.368100 0.000000 0.339570 0.339571 0.000000
16 10.360500 10.360500 0.000000 0.382450 0.382450 0.000000
17 10.367800 10.367800 0.000000 0.328462 0.328463 0.000000
18 10.362700 10.362700 0.000000 0.567761 0.567761 0.000000
19 10.359300 10.359300 0.000000 0.344158 0.344158 0.000000
20 10.363500 10.363500 0.000000 0.337636 0.337636 0.000000
21 10.352300 10.352300 0.000000 0.382984 0.382984 0.000000
22 10.364700 10.364700 0.000000 0.330023 0.330023 0.000000
23 10.365200 10.365200 0.000000 0.366450 0.366450 0.000000
24 10.351900 10.351900 0.000000 0.366239 0.366240 0.000000
25 10.345900 10.345900 0.000000 0.454505 0.454506 0.000000
26 10.353900 10.353900 0.000000 0.372731 0.372731 0.000000
27 10.351000 10.351000 0.000000 0.386128 0.386128 0.000000
28 10.362900 10.362900 0.000000 0.362428 0.362428 0.000000
29 10.356200 10.356200 0.000000 0.362041 0.362041 0.000000
30 10.361400 10.361400 0.000000 0.345147 0.345147 0.000000
31 10.357700 10.357700 0.000000 0.353345 0.353345 0.000000
32 10.358000 10.358000 0.000000 0.338220 0.338219 0.000001
33 10.357200 10.357200 0.000000 0.346525 0.346525 0.000000
34 10.338500 10.338500 0.000000 0.429826 0.429826 0.000001
35 10.338200 10.338200 0.000000 0.410369 0.410370 0.000000
36 10.362200 10.362200 0.000000 0.308196 0.308197 0.000001
37 10.338700 10.338700 0.000000 0.406986 0.406987 0.000001
38 10.355800 10.355800 0.000000 0.347940 0.347942 0.000002
39 10.337200 10.337200 0.000000 0.484625 0.484626 0.000001
40 10.355100 10.355100 0.000000 0.419877 0.419879 0.000002
41 10.357300 10.357300 0.000000 0.355641 0.355643 0.000001
42 10.361700 10.361700 0.000000 0.338817 0.338817 0.000001
43 10.327000 10.327000 0.000000 0.466670 0.466672 0.000001
44 10.351100 10.351100 0.000000 0.365030 0.365031 0.000001
45 10.360800 10.360800 0.000000 0.347445 0.347447 0.000001
46 10.315900 10.315900 0.000000 0.495173 0.495069 0.000104
47 10.345500 10.345500 0.000000 0.373585 0.373586 0.000001
48 10.339500 10.339500 0.000000 0.367941 0.367942 0.000001
49 10.318600 10.318600 0.000000 0.495867 0.495869 0.000001
50 10.368600 10.368600 0.000000 0.427715 0.427713 0.000001
  • n_loop_it=2
loss grad_norm
fused_cel no-fused absdiff fused_cel no-fused absdiff
1 10.369300 10.369300 0.000000 0.375981 0.375981 0.000000
2 10.383600 10.383600 0.000000 0.409343 0.409344 0.000000
3 10.374800 10.374800 0.000000 0.411205 0.411205 0.000000
4 10.380000 10.380000 0.000000 0.337345 0.337345 0.000000
5 10.376800 10.376800 0.000000 0.354001 0.354001 0.000000
6 10.363800 10.363800 0.000000 0.457850 0.457851 0.000000
7 10.379100 10.379100 0.000000 0.327099 0.327099 0.000000
8 10.372200 10.372200 0.000000 0.324939 0.324939 0.000000
9 10.360500 10.360500 0.000000 0.463365 0.463365 0.000000
10 10.369700 10.369700 0.000000 0.345713 0.345714 0.000000
11 10.377000 10.377000 0.000000 0.323786 0.323786 0.000000
12 10.363000 10.363000 0.000000 0.366833 0.366833 0.000000
13 10.358700 10.358700 0.000000 0.386118 0.386118 0.000000
14 10.362500 10.362500 0.000000 0.345925 0.345925 0.000000
15 10.368100 10.368100 0.000000 0.339570 0.339571 0.000000
16 10.360500 10.360500 0.000000 0.382450 0.382450 0.000000
17 10.367800 10.367800 0.000000 0.328462 0.328463 0.000000
18 10.362700 10.362700 0.000000 0.567761 0.567761 0.000000
19 10.359300 10.359300 0.000000 0.344158 0.344158 0.000000
20 10.363500 10.363500 0.000000 0.337636 0.337636 0.000001
21 10.352300 10.352300 0.000000 0.382984 0.382984 0.000000
22 10.364700 10.364700 0.000000 0.330023 0.330023 0.000000
23 10.365200 10.365200 0.000000 0.366450 0.366450 0.000000
24 10.351900 10.351900 0.000000 0.366239 0.366240 0.000000
25 10.345900 10.345900 0.000000 0.454505 0.454506 0.000000
26 10.353900 10.353900 0.000000 0.372731 0.372731 0.000000
27 10.351000 10.351000 0.000000 0.386128 0.386128 0.000000
28 10.362900 10.362900 0.000000 0.362428 0.362428 0.000000
29 10.356200 10.356200 0.000000 0.362041 0.362041 0.000000
30 10.361400 10.361400 0.000000 0.345147 0.345147 0.000000
31 10.357700 10.357700 0.000000 0.353345 0.353345 0.000000
32 10.358000 10.358000 0.000000 0.338220 0.338219 0.000001
33 10.357200 10.357200 0.000000 0.346525 0.346525 0.000000
34 10.338500 10.338500 0.000000 0.429826 0.429826 0.000000
35 10.338200 10.338200 0.000000 0.410370 0.410370 0.000000
36 10.362200 10.362200 0.000000 0.308196 0.308197 0.000000
37 10.338700 10.338700 0.000000 0.406987 0.406987 0.000000
38 10.355800 10.355800 0.000000 0.347942 0.347942 0.000000
39 10.337200 10.337200 0.000000 0.484625 0.484626 0.000000
40 10.355100 10.355100 0.000000 0.419878 0.419879 0.000000
41 10.357300 10.357300 0.000000 0.355642 0.355643 0.000001
42 10.361700 10.361700 0.000000 0.338817 0.338817 0.000000
43 10.327000 10.327000 0.000000 0.466671 0.466672 0.000000
44 10.351100 10.351100 0.000000 0.365031 0.365031 0.000000
45 10.360800 10.360800 0.000000 0.347446 0.347447 0.000001
46 10.315900 10.315900 0.000000 0.495084 0.495069 0.000015
47 10.345500 10.345500 0.000000 0.373585 0.373586 0.000001
48 10.339500 10.339500 0.000000 0.367942 0.367942 0.000000
49 10.318600 10.318600 0.000000 0.495868 0.495869 0.000000
50 10.368600 10.368600 0.000000 0.427714 0.427713 0.000001
  • n_loop_it=4
loss grad_norm
fused_cel no-fused absdiff fused_cel no-fused absdiff
1 10.369300 10.369300 0.000000 0.375981 0.375981 0.000000
2 10.383600 10.383600 0.000000 0.409343 0.409344 0.000000
3 10.374800 10.374800 0.000000 0.411205 0.411205 0.000000
4 10.380000 10.380000 0.000000 0.337345 0.337345 0.000000
5 10.376800 10.376800 0.000000 0.354001 0.354001 0.000000
6 10.363800 10.363800 0.000000 0.457850 0.457851 0.000000
7 10.379100 10.379100 0.000000 0.327099 0.327099 0.000000
8 10.372200 10.372200 0.000000 0.324939 0.324939 0.000000
9 10.360500 10.360500 0.000000 0.463365 0.463365 0.000000
10 10.369700 10.369700 0.000000 0.345713 0.345714 0.000000
11 10.377000 10.377000 0.000000 0.323786 0.323786 0.000000
12 10.363000 10.363000 0.000000 0.366833 0.366833 0.000000
13 10.358700 10.358700 0.000000 0.386118 0.386118 0.000000
14 10.362500 10.362500 0.000000 0.345925 0.345925 0.000000
15 10.368100 10.368100 0.000000 0.339570 0.339571 0.000000
16 10.360500 10.360500 0.000000 0.382450 0.382450 0.000000
17 10.367800 10.367800 0.000000 0.328462 0.328463 0.000000
18 10.362700 10.362700 0.000000 0.567761 0.567761 0.000000
19 10.359300 10.359300 0.000000 0.344158 0.344158 0.000000
20 10.363500 10.363500 0.000000 0.337636 0.337636 0.000001
21 10.352300 10.352300 0.000000 0.382984 0.382984 0.000000
22 10.364700 10.364700 0.000000 0.330023 0.330023 0.000000
23 10.365200 10.365200 0.000000 0.366450 0.366450 0.000000
24 10.351900 10.351900 0.000000 0.366239 0.366240 0.000000
25 10.345900 10.345900 0.000000 0.454506 0.454506 0.000000
26 10.353900 10.353900 0.000000 0.372731 0.372731 0.000000
27 10.351000 10.351000 0.000000 0.386128 0.386128 0.000000
28 10.362900 10.362900 0.000000 0.362428 0.362428 0.000000
29 10.356200 10.356200 0.000000 0.362041 0.362041 0.000000
30 10.361400 10.361400 0.000000 0.345147 0.345147 0.000000
31 10.357700 10.357700 0.000000 0.353345 0.353345 0.000000
32 10.358000 10.358000 0.000000 0.338220 0.338219 0.000001
33 10.357200 10.357200 0.000000 0.346525 0.346525 0.000000
34 10.338500 10.338500 0.000000 0.429826 0.429826 0.000000
35 10.338200 10.338200 0.000000 0.410370 0.410370 0.000001
36 10.362200 10.362200 0.000000 0.308197 0.308197 0.000000
37 10.338700 10.338700 0.000000 0.406987 0.406987 0.000000
38 10.355800 10.355800 0.000000 0.347942 0.347942 0.000000
39 10.337200 10.337200 0.000000 0.484626 0.484626 0.000001
40 10.355100 10.355100 0.000000 0.419879 0.419879 0.000000
41 10.357300 10.357300 0.000000 0.355643 0.355643 0.000000
42 10.361700 10.361700 0.000000 0.338818 0.338817 0.000000
43 10.327000 10.327000 0.000000 0.466672 0.466672 0.000000
44 10.351100 10.351100 0.000000 0.365031 0.365031 0.000000
45 10.360800 10.360800 0.000000 0.347446 0.347447 0.000001
46 10.315900 10.315900 0.000000 0.495063 0.495069 0.000006
47 10.345500 10.345500 0.000000 0.373586 0.373586 0.000000
48 10.339500 10.339500 0.000000 0.367942 0.367942 0.000000
49 10.318600 10.318600 0.000000 0.495869 0.495869 0.000000
50 10.368600 10.368600 0.000000 0.427715 0.427713 0.000001

Training metrics for float32:

step trainable_params total_params n_loop_iters total_flos train_loss train_mem_gpu_peaked_delta train_samples_per_second train_steps_per_second train_runtime
no-fused 50 1032272 1032272 1 74GF 10.3577 188MB 27.031 13.516 0:00:03.69
fused_cel 50 1032272 1032272 1 74GF 10.3577 66MB 27.321 13.66 0:00:03.66
fused_cel 50 1032272 1032272 2 74GF 10.3577 35MB 34.413 17.207 0:00:02.90
fused_cel 50 1032272 1032272 4 74GF 10.3577 19MB 34.124 17.062 0:00:02.93

bfloat16

  • n_loop_it=1
loss grad_norm
fused_cel no-fused absdiff fused_cel no-fused absdiff
1 10.369300 10.369300 0.000000 0.375000 0.375000 0.000000
2 10.383600 10.383600 0.000000 0.408203 0.408203 0.000000
3 10.374700 10.374800 0.000100 0.408203 0.408203 0.000000
4 10.379900 10.379900 0.000000 0.335938 0.335938 0.000000
5 10.376600 10.376600 0.000000 0.353516 0.353516 0.000000
6 10.363300 10.363300 0.000000 0.457031 0.457031 0.000000
7 10.378900 10.378900 0.000000 0.326172 0.326172 0.000000
8 10.372000 10.372000 0.000000 0.324219 0.324219 0.000000
9 10.360000 10.360000 0.000000 0.460938 0.460938 0.000000
10 10.369300 10.369300 0.000000 0.343750 0.343750 0.000000
11 10.377000 10.377000 0.000000 0.322266 0.322266 0.000000
12 10.362600 10.362600 0.000000 0.365234 0.365234 0.000000
13 10.358700 10.358700 0.000000 0.384766 0.384766 0.000000
14 10.362900 10.362900 0.000000 0.345703 0.345703 0.000000
15 10.368100 10.368100 0.000000 0.337891 0.337891 0.000000
16 10.360100 10.360100 0.000000 0.378906 0.378906 0.000000
17 10.367600 10.367700 0.000100 0.326172 0.326172 0.000000
18 10.362000 10.362100 0.000100 0.566406 0.566406 0.000000
19 10.359200 10.359100 0.000100 0.345703 0.345703 0.000000
20 10.362900 10.362900 0.000000 0.335938 0.335938 0.000000
21 10.352200 10.352300 0.000100 0.380859 0.380859 0.000000
22 10.365100 10.365000 0.000100 0.330078 0.330078 0.000000
23 10.365000 10.365000 0.000000 0.363281 0.363281 0.000000
24 10.352400 10.352500 0.000100 0.365234 0.365234 0.000000
25 10.346100 10.346100 0.000000 0.451172 0.451172 0.000000
26 10.353900 10.353800 0.000100 0.371094 0.371094 0.000000
27 10.350900 10.350800 0.000100 0.384766 0.384766 0.000000
28 10.363000 10.363300 0.000300 0.359375 0.359375 0.000000
29 10.355400 10.355300 0.000100 0.361328 0.361328 0.000000
30 10.361300 10.360500 0.000800 0.341797 0.341797 0.000000
31 10.358800 10.358900 0.000100 0.351562 0.349609 0.001953
32 10.358800 10.358900 0.000100 0.333984 0.333984 0.000000
33 10.358200 10.358300 0.000100 0.343750 0.343750 0.000000
34 10.339200 10.339300 0.000100 0.425781 0.425781 0.000000
35 10.339200 10.339200 0.000000 0.408203 0.408203 0.000000
36 10.364000 10.364000 0.000000 0.304688 0.304688 0.000000
37 10.340300 10.340100 0.000200 0.402344 0.402344 0.000000
38 10.356800 10.356700 0.000100 0.343750 0.345703 0.001953
39 10.338900 10.339200 0.000300 0.478516 0.478516 0.000000
40 10.355800 10.356000 0.000200 0.414062 0.414062 0.000000
41 10.359100 10.358800 0.000300 0.351562 0.349609 0.001953
42 10.363100 10.362700 0.000400 0.335938 0.335938 0.000000
43 10.329000 10.329400 0.000400 0.458984 0.460938 0.001953
44 10.352700 10.353000 0.000300 0.357422 0.359375 0.001953
45 10.362200 10.361900 0.000300 0.343750 0.341797 0.001953
46 10.319600 10.319500 0.000100 0.488281 0.488281 0.000000
47 10.348700 10.348500 0.000200 0.367188 0.367188 0.000000
48 10.342400 10.342000 0.000400 0.359375 0.361328 0.001953
49 10.321900 10.322000 0.000100 0.486328 0.486328 0.000000
50 10.368800 10.368500 0.000300 0.417969 0.417969 0.000000
  • n_loop_it=2
loss grad_norm
fused_cel no-fused absdiff fused_cel no-fused absdiff
1 10.369300 10.369300 0.000000 0.375000 0.375000 0.000000
2 10.383600 10.383600 0.000000 0.408203 0.408203 0.000000
3 10.374700 10.374800 0.000100 0.408203 0.408203 0.000000
4 10.379800 10.379900 0.000100 0.335938 0.335938 0.000000
5 10.376600 10.376600 0.000000 0.353516 0.353516 0.000000
6 10.363300 10.363300 0.000000 0.457031 0.457031 0.000000
7 10.378900 10.378900 0.000000 0.326172 0.326172 0.000000
8 10.372100 10.372000 0.000100 0.324219 0.324219 0.000000
9 10.359900 10.360000 0.000100 0.460938 0.460938 0.000000
10 10.369400 10.369300 0.000100 0.343750 0.343750 0.000000
11 10.377400 10.377000 0.000400 0.322266 0.322266 0.000000
12 10.362600 10.362600 0.000000 0.365234 0.365234 0.000000
13 10.358400 10.358700 0.000300 0.384766 0.384766 0.000000
14 10.363000 10.362900 0.000100 0.345703 0.345703 0.000000
15 10.367900 10.368100 0.000200 0.337891 0.337891 0.000000
16 10.360100 10.360100 0.000000 0.378906 0.378906 0.000000
17 10.367700 10.367700 0.000000 0.326172 0.326172 0.000000
18 10.362300 10.362100 0.000200 0.562500 0.566406 0.003906
19 10.359400 10.359100 0.000300 0.343750 0.345703 0.001953
20 10.363100 10.362900 0.000200 0.335938 0.335938 0.000000
21 10.352100 10.352300 0.000200 0.380859 0.380859 0.000000
22 10.365000 10.365000 0.000000 0.328125 0.330078 0.001953
23 10.364900 10.365000 0.000100 0.363281 0.363281 0.000000
24 10.352200 10.352500 0.000300 0.365234 0.365234 0.000000
25 10.346000 10.346100 0.000100 0.451172 0.451172 0.000000
26 10.354100 10.353800 0.000300 0.371094 0.371094 0.000000
27 10.351000 10.350800 0.000200 0.382812 0.384766 0.001953
28 10.363100 10.363300 0.000200 0.359375 0.359375 0.000000
29 10.355300 10.355300 0.000000 0.359375 0.361328 0.001953
30 10.361700 10.360500 0.001200 0.341797 0.341797 0.000000
31 10.358700 10.358900 0.000200 0.351562 0.349609 0.001953
32 10.358700 10.358900 0.000200 0.337891 0.333984 0.003906
33 10.357800 10.358300 0.000500 0.343750 0.343750 0.000000
34 10.339400 10.339300 0.000100 0.425781 0.425781 0.000000
35 10.339500 10.339200 0.000300 0.408203 0.408203 0.000000
36 10.363700 10.364000 0.000300 0.304688 0.304688 0.000000
37 10.339900 10.340100 0.000200 0.402344 0.402344 0.000000
38 10.356700 10.356700 0.000000 0.345703 0.345703 0.000000
39 10.339200 10.339200 0.000000 0.480469 0.478516 0.001953
40 10.355300 10.356000 0.000700 0.414062 0.414062 0.000000
41 10.359000 10.358800 0.000200 0.351562 0.349609 0.001953
42 10.362900 10.362700 0.000200 0.333984 0.335938 0.001953
43 10.328600 10.329400 0.000800 0.460938 0.460938 0.000000
44 10.353200 10.353000 0.000200 0.359375 0.359375 0.000000
45 10.362200 10.361900 0.000300 0.343750 0.341797 0.001953
46 10.319600 10.319500 0.000100 0.486328 0.488281 0.001953
47 10.348400 10.348500 0.000100 0.365234 0.367188 0.001953
48 10.342500 10.342000 0.000500 0.361328 0.361328 0.000000
49 10.321700 10.322000 0.000300 0.486328 0.486328 0.000000
50 10.369700 10.368500 0.001200 0.419922 0.417969 0.001953
  • n_loop_it=4
loss grad_norm
fused_cel no-fused absdiff fused_cel no-fused absdiff
1 10.369300 10.369300 0.000000 0.375000 0.375000 0.000000
2 10.383600 10.383600 0.000000 0.406250 0.408203 0.001953
3 10.374700 10.374800 0.000100 0.408203 0.408203 0.000000
4 10.379900 10.379900 0.000000 0.335938 0.335938 0.000000
5 10.376600 10.376600 0.000000 0.353516 0.353516 0.000000
6 10.363300 10.363300 0.000000 0.457031 0.457031 0.000000
7 10.378900 10.378900 0.000000 0.326172 0.326172 0.000000
8 10.372100 10.372000 0.000100 0.324219 0.324219 0.000000
9 10.360000 10.360000 0.000000 0.460938 0.460938 0.000000
10 10.369400 10.369300 0.000100 0.343750 0.343750 0.000000
11 10.377300 10.377000 0.000300 0.322266 0.322266 0.000000
12 10.362500 10.362600 0.000100 0.365234 0.365234 0.000000
13 10.358500 10.358700 0.000200 0.384766 0.384766 0.000000
14 10.362900 10.362900 0.000000 0.345703 0.345703 0.000000
15 10.367800 10.368100 0.000300 0.337891 0.337891 0.000000
16 10.360000 10.360100 0.000100 0.380859 0.378906 0.001953
17 10.367800 10.367700 0.000100 0.326172 0.326172 0.000000
18 10.362200 10.362100 0.000100 0.562500 0.566406 0.003906
19 10.359300 10.359100 0.000200 0.343750 0.345703 0.001953
20 10.363000 10.362900 0.000100 0.335938 0.335938 0.000000
21 10.352000 10.352300 0.000300 0.380859 0.380859 0.000000
22 10.364900 10.365000 0.000100 0.330078 0.330078 0.000000
23 10.364800 10.365000 0.000200 0.363281 0.363281 0.000000
24 10.352200 10.352500 0.000300 0.365234 0.365234 0.000000
25 10.346400 10.346100 0.000300 0.451172 0.451172 0.000000
26 10.354200 10.353800 0.000400 0.371094 0.371094 0.000000
27 10.351000 10.350800 0.000200 0.384766 0.384766 0.000000
28 10.363000 10.363300 0.000300 0.359375 0.359375 0.000000
29 10.355300 10.355300 0.000000 0.361328 0.361328 0.000000
30 10.361400 10.360500 0.000900 0.341797 0.341797 0.000000
31 10.358500 10.358900 0.000400 0.351562 0.349609 0.001953
32 10.358900 10.358900 0.000000 0.339844 0.333984 0.005859
33 10.358000 10.358300 0.000300 0.343750 0.343750 0.000000
34 10.339300 10.339300 0.000000 0.425781 0.425781 0.000000
35 10.339300 10.339200 0.000100 0.408203 0.408203 0.000000
36 10.363800 10.364000 0.000200 0.304688 0.304688 0.000000
37 10.340000 10.340100 0.000100 0.402344 0.402344 0.000000
38 10.356500 10.356700 0.000200 0.345703 0.345703 0.000000
39 10.338800 10.339200 0.000400 0.478516 0.478516 0.000000
40 10.356000 10.356000 0.000000 0.416016 0.414062 0.001953
41 10.358800 10.358800 0.000000 0.349609 0.349609 0.000000
42 10.362800 10.362700 0.000100 0.335938 0.335938 0.000000
43 10.328900 10.329400 0.000500 0.460938 0.460938 0.000000
44 10.353000 10.353000 0.000000 0.359375 0.359375 0.000000
45 10.361400 10.361900 0.000500 0.343750 0.341797 0.001953
46 10.320000 10.319500 0.000500 0.486328 0.488281 0.001953
47 10.348200 10.348500 0.000300 0.365234 0.367188 0.001953
48 10.342200 10.342000 0.000200 0.361328 0.361328 0.000000
49 10.322400 10.322000 0.000400 0.486328 0.486328 0.000000
50 10.369200 10.368500 0.000700 0.419922 0.417969 0.001953

Training metrics for bfloat16

step trainable_params total_params n_loop_iters total_flos train_loss train_mem_gpu_peaked_delta train_samples_per_second train_steps_per_second train_runtime
no-fused 50 1032272 1032272 1 74GF 10.3582 188MB 24.8 12.4 0:00:04.03
fused_cel 50 1032272 1032272 1 74GF 10.3582 128MB 24.564 12.282 0:00:04.07
fused_cel 50 1032272 1032272 2 74GF 10.3582 98MB 29.51 14.755 0:00:03.38
fused_cel 50 1032272 1032272 4 74GF 10.3582 49MB 31.764 15.882 0:00:03.14

Next Steps

  • Integrate with FastLanguageModel
  • Run tests / benchmarks on LoRA and QLoRA configs

@danielhanchen
Copy link
Contributor

Thanks Jerome and fantastic work - will check this!

@jeromeku
Copy link
Author

jeromeku commented May 13, 2024

@danielhanchen

This is very much a WIP. The PR description isn't rendering nicely -- here's a more readable version (located in the test folder).

Aside from integrating with unsloth's custom language models, need to check how kernel interplays with grad accum and grad checkpointing.

Also will likely need to rebase to tidy up the commit history.

@jeromeku
Copy link
Author

jeromeku commented May 15, 2024

@danielhanchen

Additional benchmarking for Llama3-8b with and without FastLanguageModel:

Benchmark config:

{
    "max_steps": 50,
    "dtype": "bfloat16",
    "model_id": "meta-llama/Meta-Llama-3-8B",
    "batch_size": 2,
    "max_seq_len": 512,
    "packing": true,
    "grad_accum_steps": 1,
    "load_in_4bit": true,
    "use_lora": true,
    "fused_cel_n_loop_iters": [
        1,
        2,
        4,
        8
    ]
}

QLoRA config:

- lora_alpha=16,
- lora_dropout=0.0,
- bias="none"
- task_type="CAUSAL_LM"
- BitsAndBytesConfig(
        load_in_4bit=load_in_4bit,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=bfloat16,
    )
  • transformers peft results
model step trainable_params total_params n_loop_iters total_flos train_loss train_mem_gpu_peaked_delta train_samples_per_second train_steps_per_second train_runtime
no-fused 50 20971520 4561571840 1 2153176GF 6.5549 14237MB 0.591 0.295 0:02:49.32
fused_cel 50 20971520 4561571840 1 2153176GF 6.5545 13783MB 0.363 0.182 0:04:35.26
fused_cel 50 20971520 4561571840 2 2153176GF 6.5492 13507MB 0.599 0.3 0:02:46.84
fused_cel 50 20971520 4561571840 4 2153176GF 6.554 13129MB 0.59 0.295 0:02:49.39
fused_cel 50 20971520 4561571840 8 2153176GF 6.5667 13047MB 0.587 0.293 0:02:50.40
  • unsloth FastLanguageModel results
model steps trainable_params total_params n_loop_iters total_flos train_loss train_mem_gpu_peaked_delta train_samples_per_second train_steps_per_second train_runtime
no-fused 50 20971520 4561571840 1 2153176GF 1.1378 4604MB 1.408 0.704 0:01:11.01
fused_cel 50 20971520 4561571840 1 2153176GF 1.1378 4884MB 1.38 0.69 0:01:12.45
fused_cel 50 20971520 4561571840 2 2153176GF 1.5245 4613MB 1.457 0.729 0:01:08.62
fused_cel 50 20971520 4561571840 4 2153176GF 1.5245 4237MB 1.441 0.721 0:01:09.38
fused_cel 50 20971520 4561571840 8 2153176GF 1.5245 4048MB 1.442 0.721 0:01:09.36

Some observations:

  • There are significant memory savings when using fused_cel with transformers peft at n_loop_iters > 2 with higher throughput and practically the same final training loss
  • The savings are less dramatic when using FastLanguageModel with noticeably higher (worse) final training loss at higher n_loop_iters.

@danielhanchen
Copy link
Contributor

@jeromeku Thanks for testing again! Hmm weird on the training loss being noticeable higher hmmm that is really really weird

I can understand why the VRAM reduction are less pronounced, but unsure on the training loss issues

@jeromeku
Copy link
Author

@danielhanchen

Trying to figure out why the divergence when using FastLanguageModel vs. the vanilla transformers peft model, as can be seen by the final training loss for the two tables.

When using transformers, we get both memory savings and practically the same loss while with FastLanguageModel there is significant deviation in the final loss.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants