Python / Keras - 如何访问每个时代预测?

时间:2016-04-26 12:19:39

标签: python keras

我正在使用Keras来预测时间序列。作为标准,我使用了20个时代。 我想知道我的神经网络为20个时期中的每一个预测了什么。

通过使用model.predict我在所有时期中只得到一个预测(不确定Keras如何选择它)。我想要所有预测,或者至少是10个最佳预测。

有人知道如何帮助我吗?

2 个答案:

答案 0 :(得分:9)

我认为这里有点混乱。

仅在训练神经网络时使用纪元,因此当训练停止时(在这种情况下,在第20纪元之后),则权重对应于在上一纪元计算的权重。

Keras在每个纪元后的训练期间打印验证集上的当前损失值。如果未保存每个纪元后的权重,则它们将丢失。您可以使用ModelCheckpoint回调为每个纪元保存权重,然后在模型上使用 load_weights 加载它们。

您可以在每个训练时期之后计算您的预测,方法是通过继承Callback并在 on_epoch_end 函数内的模型上调用预测来实现适当的回调。

然后使用它,实例化你的回调,制作一个列表并将其用作 model.fit 的关键字参数回调。

答案 1 :(得分:2)

以下代码将完成所需的工作:

import tensorflow as tf
import keras

# define your custom callback for prediction
class PredictionCallback(tf.keras.callbacks.Callback):    
  def on_epoch_end(self, epoch, logs={}):
    y_pred = self.model.predict(self.validation_data[0])
    print('prediction: {} at epoch: {}'.format(y_pred, epoch))

# ...

# register the callback before training starts
model.fit(X_train, y_train, batch_size=32, epochs=25, 
          validation_data=(X_valid, y_valid), 
          callbacks=[PredictionCallback()])