在tensorflow 2中不再支持提取和分配。遵循https://stackoverflow.com/a/47081613/9949099中提供的答案,可以在自定义keras回调中访问tf 1.x中的批处理结果。 在tf.keras和tf 2.0中,不支持执行急切的提取,因此为tf 1.x提供的解决方案不起作用。 有没有办法在tf.keras自定义回调的on_batch_end回调中获取y_true和y_pred?
我试图像下面那样修改在tf.1中工作的答案
from tf.keras.callbacks import Callback
class CollectOutputAndTarget(Callback):
def __init__(self):
super(CollectOutputAndTarget, self).__init__()
self.targets = [] # collect y_true batches
self.outputs = [] # collect y_pred batches
def on_batch_end(self, batch, logs=None):
# evaluate the variables and save them into lists
# How to change the following 2 lines so that in tf.2 eager execution collect the batch results
self.targets.append(K.eval(self.model._targets[0]))
self.outputs.append(K.eval(self.model.outputs[0]))
当我运行上面的代码时,代码失败,显然无法访问self.model._targets [0]或self.model.outputs [0]中的数据