使用tf.data.Dataset将数据输入具有多个输入的Keras模型

时间:2019-12-10 02:36:41

标签: python tensorflow keras tf.keras

我有一个使用tf.keras构建的模型,该模型具有两个输入(“ input_1”和“ input_2”),这些输入被馈送到网络中的不同分支。最终输出是单个输出。由于我的数据量很大,因此我想使用tf.data来处理输入管道。

我尝试了此处提供的解决方案: https://stackoverflow.com/a/52661189/5956578

但是,当我使用该解决方案运行model.fit时,出现错误:

Invalid argument: You must feed a value for the placeholder tensor 'input_2' with dtype float and shape [20,375,1242,1]
[[{{node input_2}}]]

我意识到这是因为数据集的输出是由“ input_1”和“ input_2”组成的字典,而model.fit无法正确地将其输入到相应的输入中。

任何帮助解决此问题的方法,或训练多输入tf.keras网络的替代解决方案,将不胜感激。

编辑:根据上面链接中的解决方案,这是我代码的相关部分:

...

input_1 = tf.keras.layers.Input(name='input_1', batch_size=batch_size, shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3))
input_2 = tf.keras.layers.Input(name='input_2', batch_size=batch_size, shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1))
output = KerasFunctionalAPINet(input_1, input_2)
model = tf.keras.models.Model(inputs=[input_1, input_2], outputs=output, name='Network')

...

def train_generator():
    for i in range(100):
        # Code to get images "source_1", "source_2" and labels "labels" from another python module
        yield {"input_1": source_1, "input_2": source_2}, labels
train_set = tf.data.Dataset.from_generator(train_generator, output_types=({"input_1": tf.float32, "input_2": tf.float32}, tf.float32), output_shapes=({"input_1": (IMAGE_HEIGHT, IMAGE_WIDTH, 3), "input_2": (IMAGE_HEIGHT, IMAGE_WIDTH, 1)}, (6, 1)))
train_set = train_set.batch(batch_size*2, drop_remainder=True)

def test_generator():
    for i in range(100):
        # Code to get images "source_1", "source_2" and labels "labels" from another python module
        yield {"input_1": source_1, "input_2": source_2}, labels
test_set = tf.data.Dataset.from_generator(test_generator, output_types=({"input_1": tf.float32, "input_2": tf.float32}, tf.float32), output_shapes=({"input_1": (IMAGE_HEIGHT, IMAGE_WIDTH, 3), "input_2": (IMAGE_HEIGHT, IMAGE_WIDTH, 1)}, (6, 1)))
train_set = train_set.batch(batch_size*2, drop_remainder=True)

...

model.fit(
    train_set,
    steps_per_epoch=(int)(no_of_samples/batch_size),
    epochs=epochs,
    validation_data=test_set,
    validation_steps=(int)(no_of_samples/batch_size),
    shuffle=True,
    verbose=1
)

1 个答案:

答案 0 :(得分:0)

yield [source_1, source_2], labels