TensorFlow 2.0-学习率调度器

时间:2020-03-07 17:48:09

标签: python-3.x deep-learning tensorflow2.0 learning-rate

我正在使用Python 3.7和TensorFlow 2.0,我必须使用以下学习率调度程序为160个纪元训练神经网络:

在80和120个时代将学习率降低10倍,其中初始学习率= 0.01。

我该如何编写函数以合并此学习率调度程序:

def scheduler(epoch):
    if epoch < 80:
        return 0.01
    elif epoch >= 80 and epoch < 120:
        return 0.01 / 10
    elif epoch >= 120:
        return 0.01 / 100

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

model.fit(
    x = data, y = labels,
    epochs=100, callbacks=[callback],
    validation_data=(val_data, val_labels))

这是正确的实现吗?

谢谢!

1 个答案:

答案 0 :(得分:1)

tf.keras.callbacks.LearningRateScheduler需要一个以纪元索引作为输入(整数,从0开始索引)并返回新的学习率作为输出(浮点数)的函数:

def scheduler(epoch, current_learning_rate):
    if epoch == 79 or epoch == 119:
        return current_learning_rate / 10
    else:
        return min(current_learning_rate, 0.001)

这将在第80和120阶段将学习率降低10倍,并使其保持在所有其他时期不变。