很抱歉,这个问题似乎很简单。但是请阅读Keras的保存和恢复帮助页面:
https://www.tensorflow.org/beta/tutorials/keras/save_and_restore_models
我不知道如何在训练期间使用“ ModelCheckpoint”进行保存。帮助文件提到应该提供3个文件,我只能看到一个文件,MODEL.ckpt。
这是我的代码:
checkpoint_dir = FolderName + "/tmp/model.ckpt"
cp_callback = k.callbacks.ModelCheckpoint(checkpoint_dir,verbose=1,save_weights_only=True)
parallel_model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),loss=my_cost_MSE, metrics=['accuracy])
parallel _model.fit(image, annotation, epochs=epoch,
batch_size=batch_size, steps_per_epoch=10,
validation_data=(image_val,annotation_val),validation_steps=num_batch_val,callbacks=callbacks_list)
此外,当我想通过以下方法训练负重时:
model = k.models.load_model(file_checkpoint)
我得到了错误:
"raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
ValueError: Unknown loss function:my_cost_MSE"
my-cost_MSE是我在培训中使用的成本函数。
答案 0 :(得分:1)
keras
有一个save
命令。它保存了重建模型所需的所有细节。
(来自keras docs)
from keras.models import load_model
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
del model # deletes the existing model
# returns am identical compiled model
model = load_model('my_model.h5')
答案 1 :(得分:0)
首先,看起来您正在使用tf.keras
(来自tensorflow)实现,而不是keras
(来自keras-team / keras存储库)。请注意,在这种情况下,如tf.keras guide中所述:
保存模型的权重时,tf.keras默认为检查点 格式。传递save_format ='h5'以使用HDF5。
另一方面,请注意,添加回调ModelCheckpoint
通常大致相当于调用model.save(...)
,因此这就是为什么您希望保存三个文件的原因(根据{{3 }}。
之所以不这样做,是因为通过使用选项save_weights_only=True
,您仅节省了权重。大致等效于在每个时期结束时将对model.save
的调用model.save_weights
替换为model.load_weights
。因此,您唯一要保存的文件就是具有权重的文件。
请注意,如果您只想存储权重,则需要预先加载模型(例如结构),然后调用model = MyModel(...) # Your model definition as used in training
model.load_weights(file_checkpoint)
:
my_cost_MSE
请注意,在这种情况下,自定义定义(cp_callback = k.callbacks.ModelCheckpoint(checkpoint_dir,verbose=1,save_weights_only=False)
parallel_model.compile(
optimizer=tf.keras.optimizers.Adam(lr=learning_rate),
loss=my_cost_MSE,
metrics=['accuracy]
)
# Training code here
)不会有问题,因为您只是在加载模型权重。
另一种进行方法是存储整个模型并相应地加载它:
model = k.models.load_model(file_checkpoint, custom_objects={"my_cost_MSE": my_cost_MSE})
然后您可以通过以下方式加载它:
custom_objects
请注意,在后一种情况下,您需要指定swagger-blocks
,因为需要对其定义进行反序列化。