Keras:ModelCheckpoint监视另一个回调的输出

时间:2019-10-22 11:45:52

标签: python keras callback

我有一个自定义回调,向我显示了时代末期的假阳性和真阳性的数量。我想使用ModelCheckpoint来保存最大正负误报数的模型。我尝试了以下代码,但似乎无法正常工作:

  

RuntimeWarning:仅在可用的tpfp(跳过)下才能保存最佳模型。

有人知道该怎么做吗?
谢谢你

class tpfp(keras.callbacks.Callback):
    def on_epoch_end(self,epoch,logs={}):
        x_test=self.validation_data[0]
        y_test=self.validation_data[1]
        y_pred=self.model.predict(x_test,verbose=0)
        y_pred[y_pred>.6]=1  #change threshold here
        y_pred[y_pred<1] = 0
        cm=metrics.confusion_matrix(y_test,y_pred)
        fp=cm[0,1]
        tp=cm[1,1]
        print(f'fp{fp}, tp{tp}')
        return(tp-fp)

mc = keras.callbacks.ModelCheckpoint('model.h5',monitor=tpfp(),mode='max',
                                     save_best_only=True,verbose=1)


model.fit(x_train, y_train, epochs=500, batch_size=100,
          validation_data=(x_test, y_test), callbacks=[tpfp(),mc],
          shuffle=True, verbose=2)

1 个答案:

答案 0 :(得分:0)

您不能将回调作为监视参数的参数传递。

解决问题的一种优雅/自然的方法是在方法@on_epoch_end中修改/添加一些代码行。

def on_epoch_end(self,epoch,logs={}):
        x_test=self.validation_data[0]
        y_test=self.validation_data[1]
        y_pred=self.model.predict(x_test,verbose=0)
        y_pred[y_pred>.6]=1  #change threshold here
        y_pred[y_pred<1] = 0
        cm=metrics.confusion_matrix(y_test,y_pred)
        fp=cm[0,1]
        tp=cm[1,1]
        print(f'fp{fp}, tp{tp}')
        my_custom_value = tp - fp
        logs['my_custom_metric'] = my_custom_value
        return(tp-fp)

现在是您的主菜单:

mc = keras.callbacks.ModelCheckpoint('model.h5',monitor='my_custom_metric',mode='max',
                                     save_best_only=True,verbose=1)

通过在纪元末尾添加“ logs”字典,监视器值便可以访问“ my_custom_metric”的值。