fit_generator() ValueError: Shapes are incompatible

Time: 2020-10-08 09:17:10

Tags: python numpy tensorflow keras

I have built a Seq2Seq model, and it trains fine when I use fit(). I am now training it on more data, so I need to use fit_generator(), but I get this error:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
    return step_function(self, iterator)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:796 step_function  **
    outputs = model.distribute_strategy.run(run_step, args=(data,))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
    return fn(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:789 run_step  **
    outputs = model.train_step(data)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:749 train_step
    y, y_pred, sample_weight, regularization_losses=self.losses)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/compile_utils.py:204 __call__
    loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:149 __call__
    losses = ag_call(y_true, y_pred)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:253 call  **
    return ag_fn(y_true, y_pred, **self._fn_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
    return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:1535 categorical_crossentropy
    return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
    return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:4687 categorical_crossentropy
    target.shape.assert_is_compatible_with(output.shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py:1134 assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
    
    ValueError: Shapes (None, None) and (None, 1, 17772) are incompatible
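
The last frame of the traceback is a shape compatibility check between the target (`y_true`) and the model output (`y_pred`). The sketch below is a simplified stand-in for that check, not TensorFlow's actual code; `shapes_compatible` is a hypothetical name, and it assumes both ranks are known. It shows that the two shapes from the message cannot match because their ranks differ. The output's sequence length of 1 (instead of 20) is consistent with each yielded sample of shape (20,) being read as a batch of 20 one-step sequences, though the question does not confirm this.

```python
def shapes_compatible(a, b):
    """Simplified stand-in for TensorShape.is_compatible_with.

    Shapes are tuples whose entries are ints or None (unknown size).
    Assumes both ranks are known; this is not TensorFlow's actual code.
    """
    if len(a) != len(b):
        return False  # shapes of different rank never match
    # Each pair of dimensions must be equal, or at least one must be unknown.
    return all(x is None or y is None or x == y for x, y in zip(a, b))


# Shapes taken from the error message: target vs. model output.
print(shapes_compatible((None, None), (None, 1, 17772)))        # False
# What a batched target and output would look like for this model.
print(shapes_compatible((None, 20, 17772), (None, 20, 17772)))  # True
```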

The generator:

def model_generator():
    for output in training_decoder_output:
        for i in range(500): # each output element is a chunk of 500 targets
            yield ([training_encoder_input[size*i+i], training_decoder_input[size*i+i]], output[i])

These are the shapes of my data:

print(training_decoder_output[0].shape)
# (500, 20, 17772); 10 such arrays make up the full set of targets
print(training_encoder_input.shape, training_decoder_input.shape)
# (5000, 20) (5000, 20)
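
Keras reads the leading axis of every array a generator yields as the batch axis, so a generator is expected to yield whole batches rather than single samples. The following is a minimal sketch of a batching generator, not a confirmed fix. It assumes (the question does not state this) that sample `j` of the encoder and decoder inputs belongs to row `j % 500` of target chunk `j // 500`. The names `batched_generator`, `count_steps`, and `batch_size` are hypothetical.

```python
import math


def batched_generator(encoder_input, decoder_input, target_chunks, batch_size=50):
    """Yield ((enc_batch, dec_batch), target_batch) tuples indefinitely.

    Assumes the chunks in target_chunks, laid end to end, line up
    one-to-one with the rows of encoder_input and decoder_input.
    """
    while True:  # loop forever; the caller bounds an epoch with steps_per_epoch
        offset = 0  # index of the current chunk's first sample in the inputs
        for chunk in target_chunks:
            chunk_len = len(chunk)
            for start in range(0, chunk_len, batch_size):
                stop = min(start + batch_size, chunk_len)
                enc = encoder_input[offset + start:offset + stop]
                dec = decoder_input[offset + start:offset + stop]
                # Slicing keeps the leading batch axis on every array.
                yield (enc, dec), chunk[start:stop]
            offset += chunk_len


def count_steps(target_chunks, batch_size=50):
    """Number of batches the generator yields in one pass over the data."""
    return sum(math.ceil(len(chunk) / batch_size) for chunk in target_chunks)
```

With the arrays from the question, a call might look like `model.fit(batched_generator(training_encoder_input, training_decoder_input, training_decoder_output, batch_size=50), steps_per_epoch=count_steps(training_decoder_output, 50), epochs=5)`. Each batch would then have shapes (50, 20), (50, 20) and (50, 20, 17772). In TensorFlow 2.1 and later, `fit()` accepts generators directly and `fit_generator()` is deprecated.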

This is the Keras model, built with the Keras Functional API:

def create_model(
        input_length=20,
        output_length=20):

    encoder_input = tf.keras.Input(shape=(input_length,))
    decoder_input = tf.keras.Input(shape=(output_length,))

    encoder = tf.keras.layers.Embedding(original_embedding_matrix.shape[0], original_embedding_dim, weights=[original_embedding_matrix], trainable=False)(encoder_input)
    encoder, h_encoder, u_encoder = tf.keras.layers.LSTM(64, return_state=True)(encoder)

    decoder = tf.keras.layers.Embedding(clone_embedding_matrix.shape[0], clone_embedding_dim, weights=[clone_embedding_matrix], trainable=False)(decoder_input)
    decoder = tf.keras.layers.LSTM(64, return_sequences=True)(decoder, initial_state=[h_encoder, u_encoder])
    decoder = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(clone_vocab_size+1))(decoder)

    model = tf.keras.Model(inputs=[encoder_input, decoder_input], outputs=[decoder])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    return model

model = create_model()

model.fit_generator(model_generator(), epochs=5)
    

0 answers:

No answers