如何通过EarlyStopping或ModelCheckpoint从回调使用自定义指标?

时间:2019-10-15 09:26:27

标签: keras callback

我想从另一个回调(如EarlyStopping或ModelCheckpoint)中的回调中使用自定义指标。但是我需要以某种方式保存/存储/记录该自定义指标,以便其他回调可以访问此指标?

我有:

class Metrics(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):

        self.precision = []
        self.f1s = []
        self.prc=0
        self.f1s=0

    def on_epoch_end(self, epoch, logs={}):
        score = np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]]))
        predict = np.round(np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]])))
        targ = self.validation_data[2]

        predict = (predict < 0.5).astype(np.float)


        self.prc=sklm.precision_score(targ, predict)
        self.f1s=sklm.f1_score(targ, predict)
        self.precision.append(prc)
        self.f1s.append(f1s)

        print("— val_f1: %f — val_precision: %f" %(self.f1s, self.prc))
        return

现在

metrics = Metrics()

es = EarlyStopping(monitor=metrics.prc, mode='max', verbose=1, patience=3,min_delta=0.01,restore_best_weights=True)

model.compile(loss=contrastive_loss, optimizer=adam)
model.fit([train_sen1, train_sen2], train_labels,
          batch_size=512,
          epochs=20,callbacks=[metrics,es],
          validation_data=([dev_sen1, dev_sen2], dev_labels))

不起作用,因为Earlystopping不了解自定义精度指标?

有人知道这个回调的日志语句吗?我可以在那里保存指标吗?

2 个答案:

答案 0 :(得分:0)

要了解这里到底发生了什么,您必须检查github上EarlyStopping和ModelCheckpoint类的源代码。您可以找到它here

您的代码中的问题是您没有更新“ on_epoch_end”函数中的“ logs”字典。在该词典中,EarlyStopping和ModelCheckpoint类将查找您定义为“监视器”的内容。

因此,在您的情况下,如果要使用精度分数作为监视器,则代码应如下所示:

class Metrics(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):

        self.precision = []
        self.f1scores = []
        self.prc=0
        self.f1s=0

    def on_epoch_end(self, epoch, logs={}):
        score = np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]]))
        predict = np.round(np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]])))
        targ = self.validation_data[2]

        predict = (predict < 0.5).astype(np.float)


        self.prc=sklm.precision_score(targ, predict)
        self.f1s=sklm.f1_score(targ, predict)
        self.precision.append(prc)
        self.f1scores.append(f1s)

        #Here is where I update the logs dictionary:
        logs["prc"]=self.prc
        logs["f1s"]=self.f1s

        print("— val_f1: %f — val_precision: %f" %(self.f1s, self.prc))

然后您可以在CheckpointModel和EarlyStopping中调用这些自定义指标。但是,请确保将这些回调以正确的顺序放在fit_generator中:应首先放置指标,否则运行EarlyStopping时不会更新日志。

metrics = Metrics()

es = EarlyStopping(monitor="prc", mode='max', verbose=1, patience=3,min_delta=0.01,restore_best_weights=True)

model.compile(loss=contrastive_loss, optimizer=adam)
model.fit([train_sen1, train_sen2], train_labels,
          batch_size=512,
          epochs=20,callbacks=[metrics,es],
          validation_data=([dev_sen1, dev_sen2], dev_labels))

答案 1 :(得分:0)

zealous_nightingale 的回答适用于 <NestedScrollView xmlns:android="http://schemas.android.com/apk/res/android" android:id="@+id/mainScrollView" android:fillViewport="true" android:layout_width="match_parent" android:layout_height="match_parent"> <HorizontalScrollView android:id="@+id/horizontalScroll" android:layout_width="match_parent" android:layout_height="wrap_content" /> </NestedScrollView> 回调,但是,对于 EarlyStopping 回调,您可能还需要将 ModelCheckpoint 标志设置为 _supports_tf_logs 以便更新 False字典传递给回调:

log