keras:如何通过model.train_on_batch()使用学习率衰减

时间:2018-12-04 09:47:12

标签: python keras deep-learning

在我当前的项目中,我使用keras的train_on_batch()函数进行训练,因为fit()函数不支持GAN生成器和鉴别器的交替训练。 使用(例如)亚当优化器,我必须使用optimizer = Adam(decay=my_decay)在构造函数中指定学习率衰减,并将其交给模型编译方法。 如果以后使用模型的fit()函数,此方法就可以正常工作,因为它可以在内部计算训练重复次数,但是我不知道如何使用类似的构造自己设置该值

counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # train model:
        loss = model.train_on_batch(x, y)    # how to use the current learning rate?

具有一些计算学习率的功能。 如何手动设置当前学习率?

如果这篇文章中有错误,对不起,这是我的第一个问题。

已经感谢您的帮助。

1 个答案:

答案 0 :(得分:5)

借助keras后端设置值:keras.backend.set_value(model.optimizer.lr, lr)(其中lr是浮点数,期望的学习率)适用于fit方法,并且适用于train_on_batch:

from keras import backend as K


counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # train model:
        K.set_value(model.optimizer.lr, current_learning_rate)  # set new lr
        loss = model.train_on_batch(x, y) 

希望有帮助!