遍历keras回调上的验证目标

时间:2019-07-30 19:40:43

标签: python tensorflow keras

我创建了一个自定义的Keras回调。我知道我可以使用类属性(更多here)访问验证数据

Keras预测方法会遍历所有数据集,因此我的预测变量包含(批处理大小*步骤)样本。尽管如此,目标仅包含一批数据。

如何遍历整个验证目标?

class DummyCallback(tf.keras.callbacks.Callback):

def __init__(self, steps):
    self.steps = steps
    super().__init__()

def on_epoch_end(self, epoch, logs=None):
    data = self.validation_data[0] # input data
    target = self.validation_data[1] # labels
    predictions = self.model.predict(data, steps=self.steps)
    # predictions shape: [batch_size*steps, ...]
    # target shape: [batch_size, ...]

0 个答案:

没有答案