我创建了一个自定义的Keras回调。我知道我可以使用类属性(更多here)访问验证数据
Keras预测方法会遍历所有数据集,因此我的预测变量包含(批处理大小*步骤)样本。尽管如此,目标仅包含一批数据。
如何遍历整个验证目标?
class DummyCallback(tf.keras.callbacks.Callback):
def __init__(self, steps):
self.steps = steps
super().__init__()
def on_epoch_end(self, epoch, logs=None):
data = self.validation_data[0] # input data
target = self.validation_data[1] # labels
predictions = self.model.predict(data, steps=self.steps)
# predictions shape: [batch_size*steps, ...]
# target shape: [batch_size, ...]