Keras恢复了关于新架构的培训

时间:2019-03-12 14:40:57

标签: python tensorflow optimization keras

我正在尝试对简单的NN体系结构进行修改后再进行训练。 在此示例中,使用MNIST数据集,训练了2个时期的模型,然后将模型保存在.h5文件中。然后,我重新加载模型,修改模型并重新编译。但是我想从我停下来的地方继续训练。因此,我想使用重载模型的优化器来继续训练。这是代码:

MLP = keras.models.Sequential([
        keras.layers.Dense(100, activation='sigmoid', input_shape=(784,)),
        keras.layers.Dense(10, activation='softmax')
    ])
MLP.compile(optimizer=tf.keras.optimizers.Adam(lr=0.5), loss=tf.losses.log_loss, metrics=['accuracy'])
training_output = MLP.fit(x_train, y_train, epochs=2, validation_data=(x_val, y_val), verbose=2, initial_epoch=0)

MLP.save('test.h5')



MLP = keras.models.load_model('test.h5', custom_objects={'log_loss': log_loss})
modelt = MLP
modelt = # update modelt architecture


modelt.compile(optimizer=MLP.optimizer, loss=tf.losses.log_loss, metrics=['accuracy'])
training_output = modelt.fit(x_train, y_train, epochs=4, validation_data=(x_val, y_val), verbose=2, initial_epoch=2)

问题在于:

Epoch 1/2

 - 1s - loss: 1.2690 - acc: 0.2216 - val_loss: 1.3097 - val_acc: 0.2095

Epoch 2/2

 - 1s - loss: 1.2859 - acc: 0.2030 - val_loss: 1.2420 - val_acc: 0.1760

Epoch 3/4

 - 1s - loss: 2.8945 - acc: 0.0993 - val_loss: 2.9367 - val_acc: 0.0890

Epoch 4/4

 - 1s - loss: 2.9035 - acc: 0.0993 - val_loss: 2.9367 - val_acc: 0.0890

培训依旧停滞不前,甚至更糟。我该如何解决这个问题?即使试图实例化一个新的Adam Optimizer对象并重置我可以触摸的所有值,其行为也不会改变。重用重新加载的优化器的正确程序是什么?

非常感谢您的帮助!

编辑: 简单保存并重新加载模型将返回以下结果:

lr begin: 0.15811755
begin: [0.35258077597618104, 0.1265]
lr end: 0.25961164
end: [1.0754492826461792, 0.2785]
-------------------------------------------------------------
lr begin: 0.25961164
begin: [1.0754492826461792, 0.2785]
lr end: 0.34131044
end: [1.5968322057723998, 0.2185]
-------------------------------------------------------------
lr begin: 0.34131044
begin: [1.5968322057723998, 0.2185]
lr end: 0.3903688
end: [2.8819153175354004, 0.106]
-------------------------------------------------------------
lr begin: 0.3903688
begin: [2.8819153175354004, 0.106]
lr end: 0.42264876
end: [2.8819153175354004, 0.106]

0 个答案:

没有答案