Make fused RMSNorm a registered op #199

Open
lessw2020 opened this issue Apr 5, 2024 · 1 comment
Labels: bug, enhancement

Comments

@lessw2020
Contributor

Adding this as a tracking issue to unblock #181 from landing.
Per @wanchaol:
IMO we should also register the fwd/bwd RMSNorm kernel as a PyTorch op, so that:

making it a custom op makes it compatible with PT2; I believe the FusedRMSNorm path currently graph-breaks if we turn on torch.compile
it allows other components (e.g. DTensor) to provide a sharding rule for this custom op, so that it would be compatible with tensor parallelism
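For context, a minimal sketch of what such a registration could look like with the torch.library custom-op API (PyTorch 2.4+) is below. The op name titan::fused_rms_norm is made up for illustration, and the op body is a plain eager-mode RMSNorm standing in for the fused Triton kernel; this is not the implementation landed in the PRs referenced in this issue.

```python
import torch

# Register the op under a hypothetical "titan" namespace; the name is illustrative.
@torch.library.custom_op("titan::fused_rms_norm", mutates_args=())
def fused_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Stand-in for the fused kernel: y = x / rms(x) * weight
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * rstd * weight

@fused_rms_norm.register_fake
def _(x, weight, eps):
    # Shape/dtype propagation so torch.compile can trace through the op
    # without running the real kernel (avoids the graph break noted above).
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    x, weight, eps = inputs
    ctx.save_for_backward(x, weight)
    ctx.eps = eps

def _backward(ctx, grad_out):
    x, weight = ctx.saved_tensors
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps)
    xhat = x * rstd
    # Gradient w.r.t. weight: reduce over all non-feature dims.
    grad_w = (grad_out * xhat).sum(dim=tuple(range(x.dim() - 1)))
    # Gradient w.r.t. x for y = xhat * weight.
    gx = grad_out * weight
    grad_x = (gx - xhat * (gx * xhat).mean(-1, keepdim=True)) * rstd
    return grad_x, grad_w, None  # None for the non-tensor eps argument

torch.library.register_autograd(
    "titan::fused_rms_norm", _backward, setup_context=_setup_context
)

# Usage: out = torch.ops.titan.fused_rms_norm(x, weight, 1e-6)
```

Registering the op this way gives torch.compile a fake (meta) implementation to trace through instead of graph-breaking on the kernel launch, and it gives DTensor a single op name to attach a sharding rule to for tensor parallelism.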

@tianyu-l
Contributor

tianyu-l commented May 8, 2024

Update: hit illegal memory access (IMA) issues with both my implementation (#296) and @wconstab's (#303). Working on debugging with @lessw2020.
