我正在使用tf.estimator.Estimator
将一大堆代码从tf.keras
API移植到tf.data.Dataset
,希望与提供的compile
/ { {1}}。编译的fit
和loss
实参使我感到沮丧。
基本上,我想使用损失函数,该函数以非累加的方式使用多个输出和标签,即我想提供
metrics
我无法将其提供给def custom_loss(all_labels, model_outputs):
"""
Args:
all_labels: all labels in the dataset, as a single tensor, tuple or dict
model_outputs: all outputs of model as a single tensor, tuple or dict
Returns:
single loss tensor to be averaged.
""""
...
,因为据我所知,它仅支持每个输出/标签损失的加权总和,并根据相应的标签对每个标签的形状进行假设模型输出。我无法单独创建它并使用compile
,因为如果我想让model.add_loss
处理数据集迭代,则永远无法显式访问标签张量。我曾考虑过将所有输出和标签放在一起/将它们串联在一起,但是那样我就无法监视多个model.fit
。
我可以使用metrics
编写自己的训练循环,但这迫使我复制已经在model.train_on_batch
中实现的行为,例如数据集迭代,回调,验证,分配策略等。
举例来说,我想复制以下估算器。
fit