TF-Keras-用于多输入功能API模型的Dataset.from_generator

时间:2020-08-13 21:09:29

标签: python tensorflow keras tensorflow-datasets

我有一个生成器,生成三个变量。前两个变量是两输入Keras模型(功能性API)的两个输入。我正在使用TF-Dataset来填充我的模型。代码如下:


train_dataset = tf.data.Dataset.from_generator(generator=make_generator_train,
                                                   args=[train_x_paths, train_y_int],
                                                   output_types=(tf.tuple((tf.float16, tf.float16)), tf.int8),
                                                   output_shapes=(tf.TensorShape([2]),
                                                                  tf.TensorShape([1]))).batch(batch_size=batch_size)

我得到一个TypeError

TypeError:如果浅层结构是一个序列,则输入也必须是一个序列。输入具有类型:

2 个答案:

答案 0 :(得分:0)

尝试这样:

train_dataset = tf.data.Dataset.from_generator(
        generator=make_generator_train,
       args=[train_x_paths, train_y_int],
       output_types=(tf.float16, tf.int8)
).batch(batch_size=batch_size)

大多数时候,您不需要指定output_shapes。它在运行时决定。此外,您只需要在output_types中指定输出张量的整体dtype。不是每个张量维的dtype。

答案 1 :(得分:0)

解决方案:生成器应为输入和输出按原样生成字典。