
Performance gap between manual nvfuser definition and thunder.jit #3629

Open
Priya2698 opened this issue Dec 20, 2024 · 3 comments
Comments

@Priya2698
Collaborator

Priya2698 commented Dec 20, 2024

I am seeing lower performance for thunder.jit (with the nvfuserex executor) than the manual nvFuser definition existing in the Python benchmark suite: http://nv/etb. This came up in testing PR #3394.

For size = (2048, 8192), dtype=torch.bfloat16 (on my local system with Ada card):

--------------------------------------------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                      Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.bfloat16-size=[2048_8192]]                             136.8970 (1.0)      145.0250 (1.0)      140.3881 (1.0)      2.3346 (1.68)     140.0140 (1.0)      3.4200 (1.84)          2;0        7.1231 (1.0)          10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.bfloat16-size=[2048_8192]-executor='thunder']     223.9020 (1.64)     228.9010 (1.58)     226.1649 (1.61)     1.3899 (1.0)      226.0655 (1.61)     1.8540 (1.0)           2;0        4.4216 (0.62)         10           1
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.float32-size=[2048_8192]]                              256.4510 (1.87)     265.5080 (1.83)     260.5773 (1.86)     3.0545 (2.20)     259.8870 (1.86)     4.8270 (2.60)          4;0        3.8376 (0.54)         10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.float32-size=[2048_8192]-executor='thunder']      271.0090 (1.98)     274.9130 (1.90)     273.3845 (1.95)     1.4553 (1.05)     273.9035 (1.96)     2.7580 (1.49)          5;0        3.6579 (0.51)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

The above numbers are using rmsnorm composed of primitives:

import torch

def rmsnorm_prims(inputs: list):
    inp, weights = inputs
    squared_mean = (inp**2).mean(1, keepdim=True)
    rms_eps = torch.sqrt(squared_mean + 1e-5)
    output = weights * (inp / rms_eps)
    return output

I recover some of the performance using torch.nn.functional.rms_norm. (Note that the manual nvFuser definition was generated through Thunder using the above rmsnorm_prims.)

import torch.nn.functional as F

def rmsnorm_func(inputs: list):
    inp, weights = inputs
    output = F.rms_norm(inp, inp.shape[1:], weights, eps=1e-5)
    return output
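As a sanity check on the math (not part of the benchmark), the two formulations above compute the same function: both divide by the root mean square over the normalized dimension and scale by the weights. A minimal NumPy sketch of that equivalence, with hypothetical helper names:

```python
import numpy as np

def rmsnorm_prims_np(inp, weights, eps=1e-5):
    # Mirrors rmsnorm_prims: mean of squares, sqrt, divide, scale.
    squared_mean = (inp ** 2).mean(axis=1, keepdims=True)
    rms_eps = np.sqrt(squared_mean + eps)
    return weights * (inp / rms_eps)

def rmsnorm_ref_np(inp, weights, eps=1e-5):
    # Reference RMSNorm: x / sqrt(mean(x^2) + eps) * weight.
    rms = np.sqrt(np.mean(inp ** 2, axis=-1, keepdims=True) + eps)
    return inp / rms * weights

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal(8).astype(np.float32)
assert np.allclose(rmsnorm_prims_np(x, w), rmsnorm_ref_np(x, w), atol=1e-6)
```

So any performance difference between the two paths comes from how the operations are traced and fused, not from the math itself.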
--------------------------------------------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                      Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.bfloat16-size=[2048_8192]]                             137.6300 (1.0)      143.1660 (1.0)      140.5177 (1.0)      1.9396 (1.91)     139.8885 (1.0)      3.2640 (2.55)          4;0        7.1165 (1.0)          10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.bfloat16-size=[2048_8192]-executor='thunder']     175.1710 (1.27)     178.3350 (1.25)     176.9573 (1.26)     1.0168 (1.0)      176.9435 (1.26)     1.2810 (1.0)           4;0        5.6511 (0.79)         10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.float32-size=[2048_8192]-executor='thunder']      255.0390 (1.85)     264.3810 (1.85)     258.7758 (1.84)     2.6816 (2.64)     258.5290 (1.85)     2.9120 (2.27)          3;0        3.8643 (0.54)         10           1
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.float32-size=[2048_8192]]                              258.3390 (1.88)     267.1710 (1.87)     261.6898 (1.86)     2.7284 (2.68)     261.1510 (1.87)     2.7240 (2.13)          4;1        3.8213 (0.54)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@Priya2698
Collaborator Author

Priya2698 commented Dec 20, 2024

I have a (mostly) standalone script for nsys profiling here.

I'll run a sweep using F.rms_norm. The existing fusion definition in the Python benchmarks was obtained using Thunder but modified to allow for dynamic shapes and dtypes. Some casts and broadcast ops may have been simplified, which may be responsible for the performance gap.
I'll look at the differences in the operators present.
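One lightweight way to compare the operators in the two definitions is to diff the op lists with the standard library. The operator names below are hypothetical placeholders for illustration; the real lists would be extracted from the manual nvFuser definition and the Thunder-generated trace:

```python
import difflib

# Hypothetical op sequences; in practice these would come from
# parsing the two fusion definitions being compared.
manual_ops = ["cast", "mul", "sum", "broadcast", "rsqrt", "mul"]
thunder_ops = ["cast", "mul", "sum", "broadcast", "cast", "rsqrt", "cast", "mul"]

diff = difflib.unified_diff(manual_ops, thunder_ops,
                            fromfile="manual", tofile="thunder",
                            lineterm="")
# Keep only ops that appear in the thunder trace but not the manual one
# (skipping the "+++ thunder" header line).
extra = [l for l in diff if l.startswith("+") and not l.startswith("+++")]
print(extra)
```

If the simplified casts are the culprit, the extra entries should be cast ops surrounding the reduction and normalization steps.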

@Priya2698
Collaborator Author

CC: @kevinstephano @mruberry

@mruberry

I filed Lightning-AI/lightning-thunder#1582 to also track this in the thunder repository. Looking forward to hearing the results of your analysis, @Priya2698!
