How to do JAX mixed precision properly? #25434

Answered by homerjed
FirstQuadrantSam asked this question in Q&A
Not a full answer, but the right approach depends on a number of choices and conventions.

I looked at jmp (DeepMind's mixed-precision library for JAX) when implementing something like this myself. Note how it casts the gradients, the data, and the model parameters, much like you are doing.
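To make the casting concrete, here is a minimal hand-written sketch in plain JAX. It does not use jmp's API. The helper `cast_floats`, the toy linear `loss_fn`, the learning rate, and the choice of bfloat16 as the compute dtype are all assumptions made for illustration. The idea is to keep a float32 master copy of the parameters, cast parameters and data down for the forward and backward pass, then cast the gradients back up before the update.

```python
import jax
import jax.numpy as jnp

PARAM_DTYPE = jnp.float32     # dtype of the stored "master" parameters
COMPUTE_DTYPE = jnp.bfloat16  # assumed compute dtype; float16 usually also needs loss scaling
LEARNING_RATE = 0.1           # illustrative value


def cast_floats(tree, dtype):
    """Cast every floating-point leaf of a pytree to `dtype`; leave other leaves alone."""
    def cast(x):
        x = jnp.asarray(x)
        return x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x
    return jax.tree_util.tree_map(cast, tree)


def loss_fn(params, x, y):
    # Toy linear model (hypothetical), evaluated in the compute dtype.
    pred = x @ params["w"] + params["b"]
    # Do the loss reduction in float32 so the mean is not accumulated in low precision.
    return jnp.mean((pred.astype(jnp.float32) - y.astype(jnp.float32)) ** 2)


@jax.jit
def train_step(params, x, y):
    # 1. Cast model parameters and input data down to the compute dtype.
    compute_params = cast_floats(params, COMPUTE_DTYPE)
    x = cast_floats(x, COMPUTE_DTYPE)
    # 2. Forward and backward pass; gradients come out in the compute dtype.
    loss, grads = jax.value_and_grad(loss_fn)(compute_params, x, y)
    # 3. Cast gradients back up and update the float32 master parameters.
    grads = cast_floats(grads, PARAM_DTYPE)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss
```

Calling `train_step(params, x, y)` with float32 `params` returns float32 parameters and a float32 loss, while the matrix multiply itself runs in bfloat16.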

Some operations, such as the attention softmax, need to be kept in float32.
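As a rough sketch of what keeping the softmax in float32 can look like, the snippet below upcasts the attention logits before the softmax and casts the result back afterwards. The function name `attention_weights` and the unmasked, scaled dot-product layout are assumptions for illustration, not taken from jmp or from the question.

```python
import jax
import jax.numpy as jnp


def attention_weights(q, k):
    """Attention probabilities with the softmax evaluated in float32.

    q has shape (..., q_len, d) and k has shape (..., k_len, d); both may be
    in a low-precision dtype such as bfloat16.
    """
    scale = q.shape[-1] ** -0.5
    # Logits are computed in the inputs' (low-precision) dtype.
    logits = jnp.einsum("...qd,...kd->...qk", q, k) * scale
    # Upcast so the exponentials and the normalizing sum run in float32.
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
    # Cast back so downstream matmuls stay in the compute dtype.
    return probs.astype(q.dtype)
```

The upcast only touches the logits tensor, so the extra cost is limited to this one step, while the matrix multiplies on either side of it still run in low precision.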

Answer selected by FirstQuadrantSam