我尝试在 ModelCheckpoint 回调的基础上创建自己的 tf.keras 回调。
该回调应把模型保存为 TensorFlow 的 pb(SavedModel)格式,而不是 HDF5。为此,我使用了 TensorFlow 新增的函数:
tf.contrib.saved_model.save_keras_model(self.model, filepath)
可惜我卡住了,恳请帮助。
我只对回调函数做了很小的改动,把
self.model.save(filepath, overwrite=True)
改成了
self.tf.contrib.saved_model.save_keras_model(self.model, filepath)
但运行时出现以下错误:
TypeError: super(type, obj): obj must be an instance or subtype of type
(即:obj 必须是 type 的实例或子类型的实例)
出错位置是 super(ModelCheckpoint, self).__init__() 这一行。
这似乎是一个简单的错误(也可能是这种做法本身存在根本缺陷),但我目前无法解决。感谢您的帮助。
附注:您会把它放进 callback.py(我试过,但失败了),还是直接在自己的代码中定义并调用?
代码如下:
class pb_checkpoint(tf.keras.callbacks.Callback):
    """Keras callback that checkpoints the model as a TensorFlow SavedModel.

    Adapted from ``tf.keras.callbacks.ModelCheckpoint``. Instead of
    ``self.model.save(...)`` (HDF5), every checkpoint is exported with
    ``tf.contrib.saved_model.save_keras_model``.

    Args:
        filepath: Export path. It is passed through ``str.format`` with the
            named field ``epoch`` (1-based) plus every key of the epoch
            ``logs`` dict, e.g. ``'ckpt/{epoch:02d}-{val_loss:.2f}'``.
        monitor: Name of the logged quantity compared when
            ``save_best_only`` is true.
        verbose: If > 0, print a message for every checkpoint decision.
        save_best_only: If true, export only when ``monitor`` improves on
            the best value seen so far.
        mode: 'min', 'max' or 'auto'. In 'auto' mode, monitors containing
            'acc' or starting with 'fmeasure' are maximized and all others
            are minimized. Unknown values emit a RuntimeWarning and fall
            back to 'auto'.
        period: Number of epochs between checkpoint attempts.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False,
                 mode='auto', period=1):
        # super() must name THIS class. Naming ModelCheckpoint raises
        # "TypeError: super(type, obj): obj must be an instance or subtype
        # of type", because a pb_checkpoint is not a ModelCheckpoint.
        super(pb_checkpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        # np.inf (lower case) works on every NumPy version.
        # np.Inf was removed in NumPy 2.0.
        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        elif 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
            # 'auto': accuracy-like metrics improve upward.
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            # 'auto': everything else (e.g. losses) improves downward.
            self.monitor_op = np.less
            self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        """Export a SavedModel checkpoint every ``period`` epochs.

        Args:
            epoch: Index of the epoch that just finished, as passed by Keras.
                It is reported and formatted as ``epoch + 1``.
            logs: Metric values for this epoch, or None.
        """
        # None default avoids a shared mutable dict across calls.
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            # Always save here. The verbose flag only controls the message.
            self._save(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self._save(filepath)
        elif self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))

    def _save(self, filepath):
        """Export ``self.model`` as a SavedModel rooted at ``filepath``."""
        # Use the module-level `tf`. A callback has no `self.tf` attribute.
        # NOTE(review): tf.contrib exists only in TensorFlow 1.x, and per the
        # TF 1.x docs this export goes into a timestamped subdirectory of
        # `filepath`. Confirm this for your TF version; on TF 2.x, replace this
        # call (e.g. with self.model.save(filepath, save_format='tf')).
        tf.contrib.saved_model.save_keras_model(self.model, filepath)
# Export a SavedModel only when validation loss improves, checking every epoch.
# NOTE(review): `filepath` must be defined earlier in the script; it is not
# visible in this snippet.
pb_Callback = pb_checkpoint(
    filepath,
    save_best_only=True,
    monitor='val_loss',
    mode='auto',
    period=1,
    verbose=0,
)