What does the step mean in the update function?
#5972
Dear all, I also have a question about the jax.experimental.optimizers module. In the … But if I'm training a 2-layer vanilla CNN on MNIST for … And will this … Thank you in advance for your time! :) Best,
I'd suggest reading the code of the optimizers module: it's pretty readable on the whole. Step numbers are used for at least two reasons:

1. In sgd (https://cs.opensource.google/jax/jax/+/master:jax/experimental/optimizers.py;drc=b260468b51efff40183796e98b844314b66f7686;l=246), it computes a step_size based on the current step number.
2. …

These are both standard and not specific to JAX; the only slightly unusual thing about the JAX optimizers is that they make this step-number calculation explicit and under your control. I'll also note that Optax has largely subsumed …
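To make the role of the step number concrete, here is a minimal pure-Python sketch modeled on the sgd linked above. The names sgd, make_schedule and exponential_decay below are hypothetical re-implementations, not imports from JAX, and parameters are plain floats rather than arrays. It assumes the step number is a global count of update calls (one per minibatch) that the caller owns and increments, and it only illustrates the step-size reason; it does not cover whatever else the step number is used for.

```python
from typing import Callable, Tuple, Union

# A schedule maps the step number i to the learning rate used at that step.
Schedule = Callable[[int], float]


def exponential_decay(step_size: float, decay_steps: int, decay_rate: float) -> Schedule:
    # Smooth decay: the rate is multiplied by decay_rate after every decay_steps steps.
    return lambda i: step_size * decay_rate ** (i / decay_steps)


def make_schedule(scalar_or_schedule: Union[float, Schedule]) -> Schedule:
    # A constant step size is just a schedule that ignores i.
    if callable(scalar_or_schedule):
        return scalar_or_schedule
    return lambda i: scalar_or_schedule


def sgd(step_size: Union[float, Schedule]) -> Tuple[Callable, Callable, Callable]:
    # Hypothetical stand-in for an (init, update, get_params) optimizer triple.
    schedule = make_schedule(step_size)

    def init(x0: float) -> float:
        return x0  # plain SGD keeps no state beyond the parameters

    def update(i: int, g: float, x: float) -> float:
        # The step number i only selects this update's learning rate.
        return x - schedule(i) * g

    def get_params(x: float) -> float:
        return x

    return init, update, get_params


# Minimize f(x) = x**2 (gradient 2x) with a decaying learning rate.
init, update, get_params = sgd(exponential_decay(0.1, decay_steps=100, decay_rate=0.5))
state = init(1.0)
step = 0  # global step counter: owned and incremented by the caller
for epoch in range(2):
    for _ in range(3):  # pretend each epoch has 3 minibatches
        x = get_params(state)
        state = update(step, 2.0 * x, state)
        step += 1  # not reset per epoch, so the decay continues across epochs

print(step, get_params(state))
```

With the real module (jax.example_libraries.optimizers in recent JAX releases, jax.experimental.optimizers in older ones) the loop has the same shape: you call opt_update(i, grads, opt_state) and decide yourself what i counts. With a constant step size, plain SGD ignores i entirely; with a schedule like the decay above, whether you count minibatches or epochs, and whether you reset the counter, changes the learning rate you actually get.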