tf keras,自定义损失功能,需要多个网络输出作为输入

时间:2019-09-25 07:11:38

标签: tensorflow keras tf.keras

我最近决定尝试tf2.0,尤其是高级keras API。我也不太熟悉keras,以前主要使用过tensorflow。但是,我找不到使用自定义损失函数的好方法。举一个简单的例子,假设我们有一个定义如下的多输入多输出模型:

# define model
... your network definition
...

model = tf.keras.Model(inputs=[input_1, input_2], outputs=[output_1, output_2, output_3])

让我们假设有两个自定义损失函数:my_loss_1需要output_1, output_2targets_1才能计算损失。同样,my_loss_2需要output_2, output_3targets_2。我的问题是:

  1. 是否可以通过坚持使用高级API来做到这一点?
  2. 如何将训练数据传递给模型?如果我将数据集构建为
train_ds = tf.data.Dataset.zip((input_ds_1, input_ds_2), (target_ds_1, target_ds_2))

并将其传递给model.fit,然后模型将假定使用loss_1output_1计算target_ds_1,这不是我想要的。

  1. 我什至不知道将tf.data与keras一起使用时的数据对应规则如何,在任何地方都没有说明。我的猜测是,训练数据集应为(inputs, targets)的元组。如果您的网络接受多个输入,则输入是(input_1, input_2, ...)的另一个元组,与目标相同。这是正确的吗?

  2. 是否可以重用传递给模型的数据?假设现在my_loss_1my_loss_2都需要target_1,是否需要更改train_ds以包含重复项?

train_ds = tf.data.Dataset.zip((input_ds_1, input_ds_2), (target_ds_1, (target_ds_1, target_ds_2)))

0 个答案:

没有答案