输入多个数据集到张量流模型

时间:2020-09-20 13:58:28

标签: python tensorflow keras

嗨,我正在尝试在一个模型中输入多个数据集。这是我的问题的一个示例,但是在我的情况下,我的一个模型有2个输入参数,而另一个有1个。我遇到的错误是:

Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>", "<class 'tensorflow.python.data.ops.dataset_ops.TakeDataset'>"}), <class 'NoneType'>

代码:

import tensorflow as tf

# Create first model
model1 = tf.keras.Sequential()
model1.add(tf.keras.layers.Dense(1))
model1.compile()
model1.build([None,3])

# Create second model
model2 = tf.keras.Sequential()
model2.add(tf.keras.layers.Dense(1))
model2.compile()
model2.build([None,3])


# Concatenate
fusion_model = tf.keras.layers.Concatenate()([model1.output, model2.output])
t = tf.keras.layers.Dense(1, activation='tanh')(fusion_model)
model = tf.keras.models.Model(inputs=[model1.input, model2.input], outputs=t)
model.compile()

#Datasets
ds1 = tf.data.Dataset.from_tensors(([1,2,3],1))
ds2 = tf.data.Dataset.from_tensors(([1,2,3], 2))

print(ds1)
print(ds2)
# Fit
model.fit([ds1,ds2])

运行此示例代码将产生以下结果:

Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.TensorDataset'>"}), <class 'NoneType'>

我需要使用数据集模块,因为它们提供了数据的内置延迟加载。

1 个答案:

答案 0 :(得分:1)

如评论中所述,TensorFlow模型中的TensorFlow .fit函数不支持数据集列表。

如果您真的想使用数据集,则可以使用字典作为输入,并命名输入层以匹配字典。

这是您的操作方式:

model1 = tf.keras.Sequential(name="layer_1")
model2 = tf.keras.Sequential(name="layer_2")
model.summary()

ds1 = tf.data.Dataset.from_tensors(({"layer_1": [[1,2,3]],
                                     "layer_2": [[1,2,3]]}, [[2]]))

model.fit(ds1)

一个更简单的选择是只使用张量而不是数据集。 .fit支持张量列表作为输入,因此只需使用它即可。

model = tf.keras.models.Model(inputs=[model1.input, model2.input], outputs=t)
model.compile(loss='mse')

model.summary()

a = tf.constant([[1, 2, 3]])
b = tf.constant([[1, 2, 3]])

c = tf.constant([[1]])

model.fit([a, b], c)