创建tf.keras回调以在tf 2.0培训期间保存每个批次的模型预测和目标

时间:2019-10-04 03:27:27

标签: python tensorflow tf.keras eager-execution

在tensorflow 2中不再支持提取和分配。遵循https://stackoverflow.com/a/47081613/9949099中提供的答案,可以在自定义keras回调中访问tf 1.x中的批处理结果。 在tf.keras和tf 2.0中,不支持执行急切的提取,因此为tf 1.x提供的解决方案不起作用。 有没有办法在tf.keras自定义回调的on_batch_end回调中获取y_true和y_pred?

我试图像下面那样修改在tf.1中工作的答案

from tf.keras.callbacks import Callback

class CollectOutputAndTarget(Callback):
    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()
        self.targets = []  # collect y_true batches
        self.outputs = []  # collect y_pred batches

    def on_batch_end(self, batch, logs=None):
        # evaluate the variables and save them into lists
        # How to change the following 2 lines so that in tf.2 eager execution collect the batch results
        self.targets.append(K.eval(self.model._targets[0]))
        self.outputs.append(K.eval(self.model.outputs[0]))

当我运行上面的代码时,代码失败,显然无法访问self.model._targets [0]或self.model.outputs [0]中的数据

0 个答案:

没有答案