This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Keras RAdam

[中文|English]

An unofficial Keras implementation of RAdam (Rectified Adam).

Install

pip install keras-rectified-adam

External Links

Usage

from tensorflow import keras
import numpy as np
from keras_radam import RAdam

# Build a simple model that uses the RAdam optimizer
model = keras.models.Sequential()
model.add(keras.layers.Dense(input_shape=(17,), units=3))
model.compile(RAdam(), loss='mse')

# Generate toy data: y is an exact linear function of x
x = np.random.standard_normal((4096 * 30, 17))
w = np.random.standard_normal((17, 3))
y = np.dot(x, w)

# Train the model
model.fit(x, y, epochs=5)

Using Warmup

from keras_radam import RAdam

# Warm up over the first 10% of the 10000 total steps, with min_lr as the learning-rate floor
RAdam(total_steps=10000, warmup_proportion=0.1, min_lr=1e-5)
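To give a feel for how the three warmup arguments interact, here is a minimal plain-Python sketch of a linear warmup followed by a linear decay to `min_lr`. It is an illustration under stated assumptions, not the library's code: the helper `warmup_lr` is hypothetical, the base learning rate of `1e-3` is assumed, and the exact schedule used inside `keras_radam` may differ in detail.

```python
def warmup_lr(step, base_lr=1e-3, total_steps=10000,
              warmup_proportion=0.1, min_lr=1e-5):
    """Hypothetical helper: learning rate at a given (1-based) step.

    Assumed schedule: ramp linearly from 0 to base_lr during the warmup
    phase, then decay linearly from base_lr down to min_lr.
    """
    warmup_steps = total_steps * warmup_proportion
    # Warmup phase: scale the base rate by the fraction of warmup completed.
    if warmup_steps > 0 and step <= warmup_steps:
        return base_lr * step / warmup_steps
    # Decay phase: interpolate between base_lr and min_lr, clamped at the end.
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return base_lr + (min_lr - base_lr) * progress


if __name__ == "__main__":
    # Print the rate at a few points along the schedule.
    for step in (1, 500, 1000, 5500, 10000):
        print(step, warmup_lr(step))
```

With these settings the warmup phase covers the first 1000 steps, so `total_steps` should roughly match the number of optimizer updates in the whole training run (batches per epoch times epochs).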