我正在训练神经网络,并希望每N个epoch节省模型权重以进行预测。我提出了此代码草案,它受@grovina的回复here的启发。你能提出建议吗? 预先感谢。
from keras.callbacks import Callback
class WeightsSaver(Callback):
def __init__(self, model, N):
self.model = model
self.N = N
self.epoch = 0
def on_batch_end(self, epoch, logs={}):
if self.epoch % self.N == 0:
name = 'weights%08d.h5' % self.epoch
self.model.save_weights(name)
self.epoch += 1
然后将其添加到fit调用中:每5个时代保存一次权重:
model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])
答案 0 :(得分:7)
您不需要为回调传递模型。它已经可以通过它的超级访问模型了。因此,删除__init__(..., model, ...)
自变量和self.model = model
。无论如何,您都应该能够通过self.model
访问当前模型。您还将它保存在每个批次的末尾,这不是您想要的,您可能希望它为on_epoch_end
。
但是无论如何,您可以通过幼稚的modelcheckpoint callback完成操作。您无需编写自定义代码。您可以按以下方式使用它;
mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])
答案 1 :(得分:1)
您应该在on_epoch_end上实现,而不是在on_batch_end上实现。并且将模型作为__init__
的参数传递也是多余的。
from keras.callbacks import Callback
class WeightsSaver(Callback):
def __init__(self, N):
self.N = N
self.epoch = 0
def on_epoch_end(self, epoch, logs={}):
if self.epoch % self.N == 0:
name = 'weights%08d.h5' % self.epoch
self.model.save_weights(name)
self.epoch += 1