Flash attention recompute #20603

Merged
pengwa merged 20 commits into main from pengwa/flash_attn_recompute on May 21, 2024

Conversation

pengwa
Contributor

pengwa commented May 8, 2024

Flash attn recompute

  1. Allow the PythonOp(FlashAttn) node to be recomputed correctly. 45879ff (a sketch of where this PythonOp comes from follows the list)
  2. Use a JSON file to pass the selected-to-recompute subgraphs. 3c374da
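
For context, a minimal sketch of where a PythonOp(FlashAttn) node comes from: `flash_attn_func` from the flash-attn package is backed by a `torch.autograd.Function`, and ORTModule exports such autograd functions as PythonOp nodes in the training graph, which this change allows the memory optimizer to recompute. The module below is illustrative only, not the customer model from this PR.

```python
# Illustrative only: a tiny attention block whose flash-attn call shows up as a
# PythonOp node when the model is exported through ORTModule. Shapes and names
# are assumptions, not taken from the customer model in this PR.
import torch
from flash_attn import flash_attn_func  # backed by a torch.autograd.Function


class FlashAttnBlock(torch.nn.Module):
    def __init__(self, hidden: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden // num_heads
        self.qkv = torch.nn.Linear(hidden, 3 * hidden)
        self.out = torch.nn.Linear(hidden, hidden)

    def forward(self, x):  # x: (batch, seq_len, hidden), fp16/bf16 on CUDA
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(b, s, self.num_heads, self.head_dim)
        k = k.view(b, s, self.num_heads, self.head_dim)
        v = v.view(b, s, self.num_heads, self.head_dim)
        # The fused flash-attention kernel; ORTModule sees this call as a PythonOp.
        attn = flash_attn_func(q, k, v, causal=True)
        return self.out(attn.reshape(b, s, -1))
```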

Better Memory Efficiency

The customer model can run both the PyTorch SDPA path and the Flash Attention path; this PR makes it possible for the Flash Attention path to work with ORTModule layerwise recompute. Peak memory drops from 45.x GB to 32.x GB when comparing only the transformer layers (other parts are excluded; a few more optimizations targeting those parts will follow later).
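
As a sketch of how the recompute path might be enabled, the snippet below reuses the illustrative FlashAttnBlock above and wraps it with ORTModule. The ORTMODULE_MEMORY_OPT_LEVEL environment variable follows the ORTModule memory-optimizer documentation (level 1 enables transformer-layerwise recompute); treat the exact knob as an assumption rather than part of this PR's diff.

```python
# Sketch under assumptions: ORTMODULE_MEMORY_OPT_LEVEL=1 is the documented way
# to turn on transformer-layerwise recompute for ORTModule; set it before wrapping.
import os
import torch
from onnxruntime.training.ortmodule import ORTModule

os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"  # layerwise recompute

model = FlashAttnBlock(hidden=4096, num_heads=32).cuda().to(torch.float16)
model = ORTModule(model)

x = torch.randn(2, 2048, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
loss = model(x).sum()
loss.backward()  # selected activations are recomputed here, lowering peak memory
```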

Better Perf

Using Flash Attention brings an additional 16% end-to-end time reduction, with a highly aligned loss curve.


Use a JSON File to Pass Recompute Plans

The selected recompute plan is written to a JSON file and only the file path is handed to the session, which works around the maximum length allowed for strings defined in session options.
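
The sketch below shows the idea only: the plan entries mirror the "pattern:strategy:request-count" style of the memory-optimizer config strings, but the exact JSON schema and the variable used to point ORT at the file are assumptions, not taken from this PR's diff.

```python
# Hypothetical illustration of passing the recompute plan via a JSON file
# instead of a (length-limited) session-option string.
import json
import os

plan = {
    # Entries are illustrative; they follow the "pattern:strategy:count" style
    # of the memory-optimizer config strings.
    "memory_optimization_plans": [
        "FlashAttention+:1:-1",
        "BiasGelu+:1:-1",
    ]
}

with open("recompute_plan.json", "w") as f:
    json.dump(plan, f)

# Assumed knob: point ORTModule at the file rather than embedding the plan string.
os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = "recompute_plan.json"
```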

Motivation and Context

pengwa added the training label (issues related to ONNX Runtime training; typically submitted using template) on May 8, 2024
pengwa requested review from wschin and zhijxu-MS on May 8, 2024 08:24
pengwa changed the title from "Flash attn recompute" to "Flash attention recompute" on May 8, 2024
guyang3532 previously approved these changes on May 13, 2024
pengwa merged commit 8a98874 into main on May 21, 2024
98 checks passed
pengwa deleted the pengwa/flash_attn_recompute branch on May 21, 2024 05:38