Training a TensorFlow model with a generator

Date: 2021-03-13 12:21:59

Tags: python tensorflow keras tensorflow2.0

I created a model with multiple inputs using the Keras functional API, and I'm running into problems when training it with a generator.

Model: the model has two inputs and returns one output:

import tensorflow as tf

inp1 = tf.keras.Input(shape=(1, 9,))
inp2 = tf.keras.Input(shape=(1, 9,))

x = tf.keras.layers.Concatenate(axis=1)([inp1, inp2])
x = tf.keras.layers.LSTM(128, input_shape=(2, 3,), activation='relu')(x)
x = tf.keras.layers.RepeatVector(65)(x)
x = tf.keras.layers.Dense(32, activation='relu')(x)
output = tf.keras.layers.Dense(5, activation='softmax')(x)

model = tf.keras.models.Model(inputs=[inp1, inp2], outputs=output)
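For reference, the shapes flowing through this model can be worked out by hand. The pure-Python sketch below is only bookkeeping (`layer_shapes` is a hypothetical helper, not a Keras API), with `None` standing for the batch dimension. It shows that each input the model expects is 3-D, `(batch, 1, 9)`, that the LSTM actually receives `(batch, 2, 9)` (as far as I know, the `input_shape=(2, 3,)` argument is only honored for the first layer of a `Sequential` model and is ignored here), and that the output `(batch, 65, 5)` matches the per-sample target shape `(65, 5)`.

```python
# Hand-computed shape bookkeeping for the model above (hypothetical helper,
# not a Keras function). None stands for the batch dimension.
def layer_shapes(steps=1, features=9, lstm_units=128, repeats=65,
                 hidden=32, classes=5):
    inp = (None, steps, features)         # each tf.keras.Input(shape=(1, 9))
    concat = (None, 2 * steps, features)  # Concatenate(axis=1) of the two inputs
    lstm = (None, lstm_units)             # LSTM(128) returns only the last output
    repeat = (None, repeats, lstm_units)  # RepeatVector(65)
    dense = (None, repeats, hidden)       # Dense(32) acts on the last axis
    output = (None, repeats, classes)     # Dense(5, activation='softmax')
    return {"input": inp, "concat": concat, "lstm": lstm,
            "repeat": repeat, "dense": dense, "output": output}


for name, shape in layer_shapes().items():
    print(f"{name:>7}: {shape}")
```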

Generator:

def data_generator():
    # Some hard computations here (producing inputs1, inputs2 and targets)
    for input1, input2, target in zip(inputs1, inputs2, targets):
        yield ([input1, input2], target)

gen = data_generator()

Just as a sanity check, I ran this:

for i in gen:
    bar = i
    break

print(bar)
#   ([array([[0, 0, 0, 0, 0, 0, 1, 2, 4]]), array([[0, 0, 0, 0, 0, 0, 1, 2, 4]])],
      array([[1., 0., 0., 0., 0.],
             [1., 0., 0., 0., 0.],
             [1., 0., 0., 0., 0.],
             [1., 0., 0., 0., 0.],
                   ....
             [0., 0., 0., 0., 1.]], dtype=float32))


print(bar[0][0].shape, bar[0][1].shape, bar[1].shape)
# (1, 9) (1, 9) (65, 5)

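As you can see, each item from the generator has the structure ([input1, input2], output). I believe that is how it should be, but something seems to be wrong, because when I run fit():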

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

model.fit(gen, epochs=5)

I get these warnings:

WARNING:tensorflow:Model was constructed with shape (None, 1, 9) for input KerasTensor(type_spec=TensorSpec(shape=(None, 1, 9), dtype=tf.float32, name='input_17'), name='input_17', description="created by layer 'input_17'"), but it was called on an input with incompatible shape (None, None).
WARNING:tensorflow:Model was constructed with shape (None, 1, 9) for input KerasTensor(type_spec=TensorSpec(shape=(None, 1, 9), dtype=tf.float32, name='input_18'), name='input_18', description="created by layer 'input_18'"), but it was called on an input with incompatible shape (None, None).

followed by the full traceback:

ValueError: in user code:

c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py:805 train_function  *
    return step_function(self, iterator)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py:795 step_function  **
    outputs = model.distribute_strategy.run(run_step, args=(data,))
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1259 run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2730 call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3417 _call_for_each_replica
    return fn(*args, **kwargs)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py:788 run_step  **
    outputs = model.train_step(data)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py:754 train_step
    y_pred = self(x, training=True)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:1012 __call__
    outputs = call_fn(inputs, *args, **kwargs)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\functional.py:424 call
    return self._run_internal_graph(
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\functional.py:560 _run_internal_graph
    outputs = node.layer(*args, **kwargs)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\layers\recurrent.py:660 __call__
    return super(RNN, self).__call__(inputs, **kwargs)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:998 __call__
    input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
c:\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:219 assert_input_compatibility
    raise ValueError('Input ' + str(input_index) + ' of layer ' +

ValueError: Input 0 of layer lstm_16 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (None, None)
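A likely explanation, not verified against every TensorFlow version: when `fit()` receives a plain Python generator, Keras treats each yielded item as one complete batch, so the first axis of every array is read as the batch axis. To build its input pipeline it peeks at the first item and keeps only the rank of each array, relaxing every dimension to `None`. A `(1, 9)` sample therefore becomes a rank-2 batch `(None, None)`, which is exactly the shape in the warnings, while the model was built for rank-3 batches `(None, 1, 9)`. `Concatenate` keeps that rank of 2, so the LSTM raises the `ndim` error. The numpy sketch below mimics this shape inference; `relaxed_shape` is a hypothetical helper, not a Keras function.

```python
import numpy as np

# Simplified, hypothetical re-creation of how Keras (TF 2.x) infers the spec of
# data from a plain Python generator: it peeks at the first yielded item,
# treats it as one whole batch, and keeps only each array's rank, replacing
# every dimension (including the first, assumed to be the batch axis) with None.
def relaxed_shape(array):
    return tuple(None for _ in np.shape(array))


# Per-sample arrays with the shapes data_generator() yields (see the printout above).
input1 = np.zeros((1, 9))
target = np.zeros((65, 5), dtype=np.float32)

print(relaxed_shape(input1))  # (None, None)  <- the shape reported in the warnings
print(relaxed_shape(target))  # (None, None)

# What the model was built for: a batch axis in front of each (1, 9) sample.
expected_input = (None, 1, 9)
print(np.ndim(input1), len(expected_input))  # 2 3  <- "expected ndim=3, found ndim=2"

# Adding a leading batch axis gives the rank the model expects.
print(relaxed_shape(np.expand_dims(input1, axis=0)))  # (None, None, None)
```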

I understand the error itself, but I don't understand why it is raised when I use a generator. What should I do to fix this?
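One possible fix, assuming that `fit()` treats each item yielded by a plain Python generator as a complete batch (which is consistent with the `(None, None)` shape in the warnings), is to group samples into batches before they reach `fit()`, so that every array gains a leading batch axis: inputs become `(batch, 1, 9)` and targets `(batch, 65, 5)`. This is a sketch, not a tested solution for the original data; `batched` and `dummy_samples` are hypothetical names used only for illustration.

```python
import numpy as np

# Hypothetical wrapper: collects per-sample items of the form
# ([input1, input2], target) and yields them as batches with a leading axis.
def batched(sample_gen, batch_size=32):
    x1, x2, y = [], [], []
    for (input1, input2), target in sample_gen:
        x1.append(input1)
        x2.append(input2)
        y.append(target)
        if len(y) == batch_size:
            # np.stack adds the batch axis: (1, 9) -> (batch_size, 1, 9)
            yield (np.stack(x1), np.stack(x2)), np.stack(y)
            x1, x2, y = [], [], []
    if y:  # emit the final, possibly smaller, batch
        yield (np.stack(x1), np.stack(x2)), np.stack(y)


# Hypothetical stand-in for data_generator(): same structure and per-sample shapes.
def dummy_samples(n):
    for _ in range(n):
        yield [np.zeros((1, 9)), np.zeros((1, 9))], np.zeros((65, 5), dtype=np.float32)


(xb1, xb2), yb = next(batched(dummy_samples(10), batch_size=4))
print(xb1.shape, xb2.shape, yb.shape)  # (4, 1, 9) (4, 1, 9) (4, 65, 5)
```

In the question's setup this would be used as `model.fit(batched(data_generator(), batch_size=32), epochs=5)`; `batch_size=1` amounts to adding a leading axis with `np.expand_dims`. Also note that a generator object can be iterated only once, so with `epochs=5` it will likely run out of data after the first epoch unless it loops indefinitely and `steps_per_epoch` is passed, or it is replaced by a re-iterable source such as a `tf.data.Dataset` or a `tf.keras.utils.Sequence`.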
