keras使用数据集和灵活的损耗/指标进行编译

时间:2019-02-19 01:02:25

标签: tensorflow keras

我正在使用tf.estimator.Estimator将一大堆代码从tf.keras API移植到tf.data.Dataset,希望与提供的compile / { {1}}。编译的fitloss实参使我感到沮丧。

基本上,我想使用损失函数,该函数以非累加的方式使用多个输出和标签,即我想提供

metrics

我无法将其提供给def custom_loss(all_labels, model_outputs): """ Args: all_labels: all labels in the dataset, as a single tensor, tuple or dict model_outputs: all outputs of model as a single tensor, tuple or dict Returns: single loss tensor to be averaged. """" ... ,因为据我所知,它仅支持每个输出/标签损失的加权总和,并根据相应的标签对每个标签的形状进行假设模型输出。我无法单独创建它并使用compile,因为如果我想让model.add_loss处理数据集迭代,则永远无法显式访问标签张量。我曾考虑过将所有输出和标签放在一起/将它们串联在一起,但是那样我就无法监视多个model.fit

我可以使用metrics编写自己的训练循环,但这迫使我复制已经在model.train_on_batch中实现的行为,例如数据集迭代,回调,验证,分配策略等。


举例来说,我想复制以下估算器。

fit

0 个答案:

没有答案