我最近决定尝试tf2.0,尤其是高级keras API。我也不太熟悉keras,以前主要使用过tensorflow。但是,我找不到使用自定义损失函数的好方法。举一个简单的例子,假设我们有一个定义如下的多输入多输出模型:
# define model
... your network definition
...
model = tf.keras.Model(inputs=[input_1, input_2], outputs=[output_1, output_2, output_3])
让我们假设有两个自定义损失函数:my_loss_1
需要output_1, output_2
和targets_1
才能计算损失。同样,my_loss_2
需要output_2, output_3
和targets_2
。我的问题是:
train_ds = tf.data.Dataset.zip((input_ds_1, input_ds_2), (target_ds_1, target_ds_2))
并将其传递给model.fit
,然后模型将假定使用loss_1
和output_1
计算target_ds_1
,这不是我想要的。
我什至不知道将tf.data与keras一起使用时的数据对应规则如何,在任何地方都没有说明。我的猜测是,训练数据集应为(inputs, targets)
的元组。如果您的网络接受多个输入,则输入是(input_1, input_2, ...)
的另一个元组,与目标相同。这是正确的吗?
是否可以重用传递给模型的数据?假设现在my_loss_1
和my_loss_2
都需要target_1
,是否需要更改train_ds
以包含重复项?
train_ds = tf.data.Dataset.zip((input_ds_1, input_ds_2), (target_ds_1, (target_ds_1, target_ds_2)))