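I have a model built with tf.keras that takes two inputs ("input_1" and "input_2"), which are fed into different branches of the network. The model produces a single output. Because my dataset is large, I want to use tf.data for the input pipeline.
I tried the solution given here: https://stackoverflow.com/a/52661189/5956578
However, when I run model.fit with that solution, I get this error: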
Invalid argument: You must feed a value for the placeholder tensor 'input_2' with dtype float and shape [20,375,1242,1]
[[{{node input_2}}]]
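I realize this happens because the dataset yields a dictionary with the keys "input_1" and "input_2", and model.fit does not feed those entries to the corresponding inputs.
Any help fixing this, or an alternative way to train a multi-input tf.keras network, would be greatly appreciated.
Edit: Following the solution in the link above, here is the relevant part of my code: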
...
input_1 = tf.keras.layers.Input(name='input_1', batch_size=batch_size, shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3))
input_2 = tf.keras.layers.Input(name='input_2', batch_size=batch_size, shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1))
output = KerasFunctionalAPINet(input_1, input_2)
model = tf.keras.models.Model(inputs=[input_1, input_2], outputs=output, name='Network')
...
def train_generator():
    for i in range(100):
        # Code to get images "source_1", "source_2" and labels "labels" from another python module
        yield {"input_1": source_1, "input_2": source_2}, labels

train_set = tf.data.Dataset.from_generator(
    train_generator,
    output_types=({"input_1": tf.float32, "input_2": tf.float32}, tf.float32),
    output_shapes=({"input_1": (IMAGE_HEIGHT, IMAGE_WIDTH, 3), "input_2": (IMAGE_HEIGHT, IMAGE_WIDTH, 1)}, (6, 1)))
train_set = train_set.batch(batch_size*2, drop_remainder=True)

def test_generator():
    for i in range(100):
        # Code to get images "source_1", "source_2" and labels "labels" from another python module
        yield {"input_1": source_1, "input_2": source_2}, labels

test_set = tf.data.Dataset.from_generator(
    test_generator,
    output_types=({"input_1": tf.float32, "input_2": tf.float32}, tf.float32),
    output_shapes=({"input_1": (IMAGE_HEIGHT, IMAGE_WIDTH, 3), "input_2": (IMAGE_HEIGHT, IMAGE_WIDTH, 1)}, (6, 1)))
test_set = test_set.batch(batch_size*2, drop_remainder=True)
...
model.fit(
    train_set,
    steps_per_epoch=int(no_of_samples / batch_size),
    epochs=epochs,
    validation_data=test_set,
    validation_steps=int(no_of_samples / batch_size),
    shuffle=True,
    verbose=1
)
Answer 0 (score: 0)
yield [source_1, source_2], labels
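To show where that yield fits, here is a minimal end-to-end sketch, not the asker's actual pipeline. It rests on several assumptions. The toy `build_model` is a hypothetical stand-in for `KerasFunctionalAPINet`. The image sizes, `BATCH_SIZE` and `NUM_SAMPLES` are made-up small values, and random arrays replace the real images. It uses `output_signature`, which needs a reasonably recent TensorFlow 2.x, in place of `output_types`/`output_shapes`. It also yields a tuple rather than a list, because as far as I know tf.data converts Python lists to a single tensor instead of treating them as a nested structure.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy sizes so the sketch runs quickly (the question uses 375x1242 images).
IMAGE_HEIGHT, IMAGE_WIDTH = 8, 16
BATCH_SIZE, NUM_SAMPLES = 4, 16


def build_model():
    """Toy two-branch network, a stand-in for the question's KerasFunctionalAPINet."""
    input_1 = tf.keras.layers.Input(name="input_1", shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3))
    input_2 = tf.keras.layers.Input(name="input_2", shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1))
    branch_1 = tf.keras.layers.Flatten()(input_1)
    branch_2 = tf.keras.layers.Flatten()(input_2)
    merged = tf.keras.layers.Concatenate()([branch_1, branch_2])
    output = tf.keras.layers.Reshape((6, 1))(tf.keras.layers.Dense(6)(merged))
    return tf.keras.models.Model(inputs=[input_1, input_2], outputs=output)


def train_generator():
    """Yield ((source_1, source_2), labels); random arrays replace the real images."""
    rng = np.random.default_rng(0)
    for _ in range(NUM_SAMPLES):
        source_1 = rng.random((IMAGE_HEIGHT, IMAGE_WIDTH, 3), dtype=np.float32)
        source_2 = rng.random((IMAGE_HEIGHT, IMAGE_WIDTH, 1), dtype=np.float32)
        labels = rng.random((6, 1), dtype=np.float32)
        # A tuple of inputs, in the same order as the model's `inputs` list.
        yield (source_1, source_2), labels


def make_dataset():
    """Wrap the generator in a batched, repeating tf.data pipeline."""
    signature = (
        (
            tf.TensorSpec(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1), dtype=tf.float32),
        ),
        tf.TensorSpec(shape=(6, 1), dtype=tf.float32),
    )
    dataset = tf.data.Dataset.from_generator(train_generator, output_signature=signature)
    # repeat() lets fit() draw steps_per_epoch batches in every epoch.
    return dataset.batch(BATCH_SIZE, drop_remainder=True).repeat()


model = build_model()
model.compile(optimizer="adam", loss="mse")
history = model.fit(
    make_dataset(),
    steps_per_epoch=NUM_SAMPLES // BATCH_SIZE,
    epochs=1,
    verbose=0,
)
```

Note that the Input layers here leave the batch dimension unspecified. If you fix it with `batch_size=...` as in the question, the dataset presumably has to be batched with that same size.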