如何在Keras中获取历史回调指标?

时间:2019-02-09 22:32:31

标签: python keras

如何检索回调指标的历史记录?我有一个类Metrics,并在Keras模型的fit函数中使用它,如下所示callbacks=[model_metrics]

这是类Metricsfit函数的完整代码。

class Metrics(Callback):

    def on_train_begin(self, logs={}):
        self.val_f1s = []
        self.val_bal_accs = []

    def on_epoch_end(self, epoch, logs={}):
        val_predict = np.argmax((np.asarray(self.model.predict(self.validation_data[0]))).round(), axis=1)
        val_targ = np.argmax(self.validation_data[1], axis=1)
        _val_f1 = metrics.f1_score(val_targ, val_predict, average='weighted')
        _val_bal_acc = metrics.balanced_accuracy_score(val_targ, val_predict)    
        self.val_f1s.append(_val_f1)
        self.val_bal_accs.append(_val_bal_acc)
        print(" — val_f1: {:f} — val_bal_acc: {:f}".format(_val_f1, _val_bal_acc))
        return

model_metrics = Metrics()

history = model.fit(np.array(X_train), y_train, 
                    validation_data=(np.array(X_test), y_test),
                    epochs=5,
                    batch_size=2,
                    callbacks=[model_metrics],
                    shuffle=False,
                    verbose=1)

如何获得historyval_f1中的val_bal_acc?现在我只能访问lossval_lossaccval_acc

print(history.history.keys())

1 个答案:

答案 0 :(得分:1)

要与keras历史API进行交互,您需要传递metrics而不是callbacks的参数。

在当前状态下,您的val_f1val_bal_acc不会存储在历史记录对象中,而是会存储在model_metrics对象中。

您可以像这样访问它们:

model_metrics.val_f1s

与访问任何对象的属性相同。

最后,如果您确实想创建自定义指标并希望从历史记录中访问它,则需要定义一个自定义指标(作为函数),然后将其传递到metrics中的model.compile kwarg中。这样做是这样的:

def my_metric(y_true y_pred):
    return y_true # just a dummy return value

# assume that the model is defined somewhere
model.compile(loss=..., optimizer=..., metrics = [my_metric]

然后您将能够在不适合的历史对象中找到val_my_metric