回调函数中的TypeError

时间:2019-05-21 14:20:38

标签: python tensorflow callback tf.keras

我尝试在modelcheckpoint回调的基础上创建自己的tf.keras回调。

回调将模型另存为Tensorflow pb而不是hdf5。因此,我使用了新的Tensorflow功能

tf.contrib.saved_model.save_keras_model(self.model, filepath)

不幸的是,我被困住了,感谢您的帮助。

我按小部分修改了回调函数,如下所示:

  • 我将保存部分从self.model.save(filepath, overwrite=True)更改为self.tf.contrib.saved_model.save_keras_model(self.model, filepath)
  • 第二,我删除了self.save_weights_only并覆盖了它,因为新的save_keras_model没有这些功能。

但是我收到以下错误

  

TypeError:super(type,obj):obj必须是类型的实例或子类型

super(ModelCheckpoint, self).__init__()

这似乎是一个简单的错误(或该方法的主要缺陷),但是我目前无法解决。感谢您的帮助。

p.s。您会尝试将其放入callback.py(我尝试失败)还是直接在我的代码中调用它?


代码如下:

class pb_checkpoint(tf.keras.callbacks.Callback):
    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False,
                 mode='auto', period=1):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.period = period
        self.epochs_since_last_save = 0

    if mode not in ['auto', 'min', 'max']:
        warnings.warn('ModelCheckpoint mode %s is unknown, '
                      'fallback to auto mode.' % (mode),
                      RuntimeWarning)
        mode = 'auto'

    if mode == 'min':
        self.monitor_op = np.less
        self.best = np.Inf
    elif mode == 'max':
        self.monitor_op = np.greater
        self.best = -np.Inf
    else:
        if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            self.monitor_op = np.less
            self.best = np.Inf

def on_epoch_end(self, batch, logs={}):
    logs = logs or {}
    self.epochs_since_last_save += 1
    if self.epochs_since_last_save >= self.period:
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)
        if self.save_best_only:
            current = logs.get(self.monitor)
            if current is None:
                warnings.warn('Can save best model only with %s available, '
                              'skipping.' % (self.monitor), RuntimeWarning)
            else:
                if self.monitor_op(current, self.best):
                    if self.verbose > 0:
                        print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                              ' saving model to %s'
                              % (epoch + 1, self.monitor, self.best,
                                 current, filepath))
                    self.best = current  
                    self.tf.contrib.saved_model.save_keras_model(self.model, filepath)

                else:
                    if self.verbose > 0:
                        print('\nEpoch %05d: %s did not improve from %0.5f' %
                              (epoch + 1, self.monitor, self.best))
        else:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            else:
                self.tf.contrib.saved_model.save_keras_model(self.model, filepath)

pb_Callback = pb_checkpoint(filepath=filepath, monitor='val_loss',
                                verbose=0, 
                                save_best_only=True, 
                                mode='auto', 
                                period=1,
                                )

0 个答案:

没有答案