我正在使用train_on_batch
批量培训我的数据,但似乎train_on_batch
没有使用回调的选项,这似乎是使用检查点的必要条件。
我无法使用model.fit
,因为这似乎要求我将所有数据加载到内存中。
model.fit_generator
给了我奇怪的问题(比如挂在一个时代的末尾)。
以下是来自Keras API文档的示例,其中显示了ModelCheckpoint
的使用:
from keras.callbacks import ModelCheckpoint
model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
checkpointer = ModelCheckpoint(filepath='/tmp/weights.hdf5', verbose=1,
save_best_only=True)
model.fit(x_train, y_train, batch_size=128, epochs=20,
verbose=0, validation_data=(X_test, Y_test), callbacks=[checkpointer])
答案 0 :(得分:0)
如果您手动对每个批次进行训练,则可以在任何#epoch(#batch)处执行所需的操作。无需使用回调,只需调用model.save
或model.save_weights
。