I'm currently trying to build a seq2seq model for text summarization, and I run into an error when I try to run the inference model; I'm not sure what I'm doing wrong. I suspect it has something to do with the attention layer, because the model starts working again when I remove it. Any help or pointers would be greatly appreciated.
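For context, the shapes in the error message at the bottom imply roughly the following values (EMBEDDING_DIM is just a placeholder; the others are read off the traceback, so treat them as assumptions):

MAX_SEQUENCE_LENGTH = 300  # from the (None, 300, 300) shape in the warning
latent_dim = 300           # likewise from that shape
num_words = 5004           # from the (None, 300, 5004) output tensor in the error
EMBEDDING_DIM = 128        # placeholder value, not recoverable from the traceback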
Encoder-decoder model -
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Input, LSTM, Embedding, Dense,
                                     TimeDistributed)
from tensorflow.keras.models import Model

encoder_inputs = Input(shape=(MAX_SEQUENCE_LENGTH,))
# embedding layer
enc_emb = Embedding(num_words, EMBEDDING_DIM, trainable=True)(encoder_inputs)
# encoder lstm 1
encoder_lstm1 = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.2)
encoder_output1, state_h1, state_c1 = encoder_lstm1(enc_emb)
# encoder lstm 2
encoder_lstm2 = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.2)
encoder_output2, state_h2, state_c2 = encoder_lstm2(encoder_output1)
# encoder lstm 3
encoder_lstm3 = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.2)
encoder_outputs, state_h, state_c = encoder_lstm3(encoder_output2)
encoder_states = [state_h, state_c]
# Set up the decoder
decoder_inputs = Input(shape=(None,))
# embedding layer
dec_emb_layer = Embedding(num_words, EMBEDDING_DIM, trainable=True)
dec_emb = dec_emb_layer(decoder_inputs)
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.2)
# the LSTM returns (outputs, hidden state, cell state); the two states are unused here
decoder_outputs, decoder_state_h, decoder_state_c = decoder_lstm(dec_emb, initial_state=encoder_states)
# attention layer: Attention expects [query, value], i.e. the decoder outputs
# attend over the encoder outputs
attention_layer = tf.keras.layers.Attention()
attention = attention_layer([decoder_outputs, encoder_outputs])
decoder_outputs = tf.keras.layers.concatenate([attention, decoder_outputs])
# dense layer
decoder_dense = TimeDistributed(Dense(num_words, activation='softmax'))
decoder_outputs = decoder_dense(decoder_outputs)
# Define the model
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
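As far as I understand, tf.keras.layers.Attention takes [query, value] and returns a tensor with the query's time axis, which is why the decoder outputs come first in the attention call above. A minimal standalone shape check with toy sizes, just to illustrate the convention:

q = tf.random.normal((2, 5, 16))   # decoder outputs: (batch, dec_steps, dim)
v = tf.random.normal((2, 9, 16))   # encoder outputs: (batch, enc_steps, dim)
ctx = tf.keras.layers.Attention()([q, v])
print(ctx.shape)                   # (2, 5, 16): one context vector per decoder step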
Inference model -
encoder_model = Model(encoder_inputs, encoder_states)
# Decoder setup
# Below tensors will hold the states of the previous time step
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
# the full encoder output sequence, needed by the attention layer
decoder_hidden_state_input = Input(shape=(MAX_SEQUENCE_LENGTH, latent_dim))
# Get the embeddings of the decoder sequence
dec_emb2 = dec_emb_layer(decoder_inputs)
# To predict the next word in the sequence, set the initial states to the states from the previous time step
decoder_outputs2, state_h2, state_c2 = decoder_lstm(dec_emb2, initial_state=decoder_states_inputs)
# same [query, value] order as in training: decoder outputs attend over encoder outputs
attention = attention_layer([decoder_outputs2, decoder_hidden_state_input])
decoder_concat = tf.keras.layers.concatenate([attention, decoder_outputs2])
# A dense softmax layer to generate prob dist. over the target vocabulary
decoder_outputs2 = decoder_dense(decoder_concat)
# Final decoder model
decoder_model = Model(
    [decoder_inputs, decoder_hidden_state_input] + decoder_states_inputs,
    [decoder_outputs2, state_h2, state_c2])
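One thing I notice while writing this up: with attention, decoder_model takes four inputs (the target tokens, the encoder output sequence, and the two states), but encoder_model above only returns the two states. I suspect the encoder side would also have to expose encoder_outputs so the attention layer has something to attend over at inference time; a sketch of what I mean:

encoder_model = Model(
    inputs=encoder_inputs,
    outputs=[encoder_outputs, state_h, state_c])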
states_value = encoder_model.predict(encoder_input_data[10].reshape(1, -1))
target_seq = np.zeros((1, 1))
target_seq[0, 0] = word_dict['<s>']
stop_condition = False
decoded_sentence = ''
while not stop_condition:
    output_tokens, h, c = decoder_model.predict(
        [target_seq] + states_value)
    sampled_token_index = np.argmax(output_tokens[0, -1, :])
    sampled_char = reversed_dict[sampled_token_index]
    decoded_sentence += ' ' + sampled_char
    if sampled_char == "</s>" or len(decoded_sentence) > 52:
        stop_condition = True
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = sampled_token_index
    states_value = [h, c]
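If the encoder model were changed as sketched above, I assume the decode loop would have to thread the encoder output sequence through as well, something like:

e_out, e_h, e_c = encoder_model.predict(encoder_input_data[10].reshape(1, -1))
target_seq = np.zeros((1, 1))
target_seq[0, 0] = word_dict['<s>']
decoded_sentence = ''
while True:
    # pass all four inputs: tokens, encoder output sequence, and both states
    output_tokens, h, c = decoder_model.predict([target_seq, e_out, e_h, e_c])
    sampled_token_index = np.argmax(output_tokens[0, -1, :])
    sampled_char = reversed_dict[sampled_token_index]
    decoded_sentence += ' ' + sampled_char
    if sampled_char == '</s>' or len(decoded_sentence) > 52:
        break
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = sampled_token_index
    e_h, e_c = h, c  # carry the states to the next step; e_out stays fixed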
Error message -
WARNING:tensorflow:Model was constructed with shape (None, 300, 300) for input Tensor("input_105:0", shape=(None, 300, 300), dtype=float32), but it was called on an input with incompatible shape (None, 300).
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-153-9c955bd0353a> in <module>()
6 while not stop_condition:
7 output_tokens, h, c = decoder_model.predict(
----> 8 [target_seq] + states_value)
9 sampled_token_index = np.argmax(output_tokens[0, -1, :])
10 sampled_char = reversed_dict[sampled_token_index]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
AssertionError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1147 predict_function *
outputs = self.distribute_strategy.run(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:951 run **
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1122 predict_step **
return self(x, training=False)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:719 call
convert_kwargs_to_constants=base_layer_utils.call_context().saving)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:899 _run_internal_graph
assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
AssertionError: Could not compute output Tensor("time_distributed_16_1/Identity:0", shape=(None, 300, 5004), dtype=float32)