How to do JAX mixed precision properly? #25434
Hi All!
and then send the calculated gradient to the optimizer, etc. It turns out that this is actually slower than the full-precision version. Note that I used jax.block_until_ready and timed only the gradient calculation. Then I tried the following:
And it works well with the random data, giving the desired 2x acceleration. Any ideas or suggestions about what could be wrong?
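The original code blocks didn't survive here, but the pattern described (cast to low precision, take the gradient, cast back for the optimizer, and time with jax.block_until_ready) might look roughly like the following sketch. The loss function, parameter pytree, and input are placeholders, not taken from the post:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Toy quadratic loss; stands in for the real model.
    return jnp.sum((params["w"] * x) ** 2)

params = {"w": jnp.ones((4,), dtype=jnp.float32)}
x = jnp.arange(4, dtype=jnp.float32)

# Cast parameters and data to the low-precision compute dtype.
params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)
x_bf16 = x.astype(jnp.bfloat16)

grad_fn = jax.jit(jax.grad(loss_fn))
grads = grad_fn(params_bf16, x_bf16)

# Cast the gradients back to float32 before the optimizer update.
grads_f32 = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32), grads)

# Block on the result so that any timing measures the actual computation,
# not just JAX's asynchronous dispatch.
jax.block_until_ready(grads_f32)
```

Whether this gives a speedup depends heavily on the backend: on hardware without native bfloat16 matmul support, the extra casts can easily make it slower than pure float32.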
Replies: 1 comment
Not a full answer, but this depends on a number of choices/conventions. I looked at jmp to implement something like this myself. Note the casting (as you are doing) for the grads/data/model. Some operations need to be kept in float32, such as the attention softmax.
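The point about keeping the softmax in float32 can be illustrated with a minimal pure-JAX sketch (not using jmp itself): the large matmuls run in bfloat16, while the softmax, which is more sensitive to reduced precision, is up-cast to float32 and the result cast back. The attention shapes and names here are illustrative only:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # Big matmuls stay in the low-precision compute dtype (bfloat16).
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1]).astype(q.dtype)
    # Up-cast for the softmax to keep it numerically stable,
    # then cast back to the compute dtype.
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1)
    return weights.astype(q.dtype) @ v

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16), dtype=jnp.bfloat16)
k = jax.random.normal(jax.random.PRNGKey(1), (8, 16), dtype=jnp.bfloat16)
v = jax.random.normal(jax.random.PRNGKey(2), (8, 16), dtype=jnp.bfloat16)
out = attention(q, k, v)
```

jmp packages this idea as a "policy" (separate dtypes for params, compute, and output) so the casts don't have to be scattered by hand through the model code.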