diff --git a/optimization.py b/optimization.py
index d33dabd91..41217e3f3 100644
--- a/optimization.py
+++ b/optimization.py
@@ -30,7 +30,7 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
 
   # Implements linear decay of the learning rate.
   learning_rate = tf.train.polynomial_decay(
-      learning_rate,
+      learning_rate * num_train_steps / (num_train_steps - num_warmup_steps),
       global_step,
       num_train_steps,
       end_learning_rate=0.0,