使用Keras中的train_on_batch保存检查点

时间:2017-11-26 23:28:15

标签: python deep-learning keras

我正在使用train_on_batch批量培训我的数据,但似乎train_on_batch没有使用回调的选项,这似乎是使用检查点的必要条件。

我无法使用model.fit,因为这似乎要求我将所有数据加载到内存中。

model.fit_generator给了我奇怪的问题(比如挂在一个时代的末尾)。

以下是来自Keras API文档的示例,其中显示了ModelCheckpoint的使用:

from keras.callbacks import ModelCheckpoint

model = Sequential() 
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

checkpointer = ModelCheckpoint(filepath='/tmp/weights.hdf5', verbose=1, 
                               save_best_only=True)
model.fit(x_train, y_train, batch_size=128, epochs=20, 
          verbose=0, validation_data=(X_test, Y_test), callbacks=[checkpointer])

1 个答案:

答案 0 :(得分:0)

如果您手动对每个批次进行训练,则可以在任何#epoch(#batch)处执行所需的操作。无需使用回调,只需调用model.savemodel.save_weights