I am seeing lower performance for thunder.jit (with the nvfuserex executor) than with the manual nvfuser definition in the Python benchmark suite: http://nv/etb. This came up in testing PR #3394.
For size = (2048, 8192), dtype=torch.bfloat16 (on my local system with an Ada card):

The above numbers are using an rmsnorm composed of primitives:
I recover some of the performance using torch.nn.functional.rms_norm (note that the manual nvfuser definition was generated through Thunder using the rmsnorm_prims above):
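For reference, a minimal sketch of the two variants being compared; the actual rmsnorm_prims used in the benchmark is not reproduced here, so the exact casts and broadcasts may differ, and the function names below are placeholders:

```python
import torch
import thunder

def rmsnorm_prims(x, weight, eps=1e-5):
    # rmsnorm spelled out as elementwise primitives; the casts/broadcasts in
    # the real benchmark definition may differ from this sketch
    x_fp32 = x.to(torch.float32)
    var = (x_fp32 * x_fp32).mean(dim=-1, keepdim=True)
    y = x_fp32 * torch.rsqrt(var + eps)
    return (y * weight).to(x.dtype)

def rmsnorm_functional(x, weight, eps=1e-5):
    # library rmsnorm available in recent PyTorch versions
    return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)

x = torch.randn(2048, 8192, dtype=torch.bfloat16, device="cuda")
w = torch.randn(8192, dtype=torch.bfloat16, device="cuda")

# nvfuserex is in Thunder's default executor list for CUDA tensors
jit_prims = thunder.jit(rmsnorm_prims)
jit_functional = thunder.jit(rmsnorm_functional)

out_prims = jit_prims(x, w)
out_functional = jit_functional(x, w)
```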
I have a (mostly) standalone script for nsys profiling here.
I'll run a sweep using F.rms_norm. The existing fusion definition in the Python benchmarks was obtained through Thunder but modified to allow dynamic shapes and dtypes; some casts and broadcast ops may have been simplified, which could account for the performance gap.
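A rough version of the sweep I have in mind (the shapes, dtypes, and timing helper below are illustrative, not the benchmark suite's own harness):

```python
import torch
import thunder

def time_fn(fn, *args, iters=100):
    # crude CUDA-event timing; the benchmark suite uses its own harness
    for _ in range(10):
        fn(*args)  # warmup / compilation
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

jfn = thunder.jit(
    lambda x, w: torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=w)
)

for dtype in (torch.bfloat16, torch.float16, torch.float32):
    for size in ((2048, 8192), (4096, 8192), (8192, 8192)):
        x = torch.randn(*size, dtype=dtype, device="cuda")
        w = torch.randn(size[-1], dtype=dtype, device="cuda")
        print(dtype, size, f"{time_fn(jfn, x, w):.4f} ms")
```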
I'll look at the differences in the operators present.
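One way to compare is to dump Thunder's execution trace and diff it against the manual fusion definition; a minimal sketch, assuming thunder.last_traces is available as in recent Thunder versions:

```python
import torch
import thunder

jfn = thunder.jit(
    lambda x, w: torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=w)
)
x = torch.randn(2048, 8192, dtype=torch.bfloat16, device="cuda")
w = torch.randn(8192, dtype=torch.bfloat16, device="cuda")
jfn(x, w)

# The last trace is the execution trace; it lists the ops (including any
# casts and broadcasts) that were handed to the nvFuser executor.
print(thunder.last_traces(jfn)[-1])
```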