如何在sklearn中交叉验证keras包装器估算器时访问历史对象?

时间:2018-04-13 11:32:48

标签: scikit-learn keras

我希望在交叉验证中看到每个拆分的丢失/错误进展。 keras.wrappers.scikit_learn.KerasClassifier的fit方法返回一个包含我想要的数据的history对象,但在sklearn.model_selection.cross_validate变体方法中运行它时无法访问它。

如何在每次拆分中访问每个纪元的历史对象?

1 个答案:

答案 0 :(得分:0)

您可能可以使用CSVLogger回调来访问完整的历史记录。设置CSVLogger回调很容易,它将以您指定的任何文件名记录{epoch,acc,loss,val_acc,val_loss}。

在我的代码中,我做类似的事情:

keras_classifier.fit(X, y, groups=None, 
    callbacks=[keras.callbacks.CSVLogger(filename, append=True)])

设置append=True应该确保所有拆分的所有数据都包含在文件中。

注意事项:

  • 我不确定这是否可以与n_jobs=-1一起使用(用于在多个处理器上分配处理),但是如果您运行单线程,它应该可以工作。
  • 确保在运行分类器之前(或在初始化过程中)删除文件,以避免无限期附加到该文件。