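Chollet published an introduction to Seq2Seq models in Keras on the Keras blog: here. One thing I don't understand is what the inference model for the GRU seq2seq
model should look like. He gives the code for building the encoder-decoder below, but he doesn't explain how to change the inference code to match.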
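In the blog's LSTM version the encoder and decoder expose two state tensors (`state_h` and `state_c`), while a GRU has only one. A minimal sketch with dummy data (the batch shape and the 8-unit layer size are arbitrary assumptions) showing how many tensors each layer returns depending on `return_state`:

```python
import tensorflow as tf
from tensorflow.keras.layers import GRU, LSTM

# Dummy batch: 2 sequences, 5 timesteps, 3 features per step.
x = tf.zeros((2, 5, 3))

# Without return_state a GRU returns a single tensor (the last output).
gru_plain = GRU(8)(x)

# With return_state=True a GRU returns [output, state_h]: ONE state tensor.
gru_out, gru_h = GRU(8, return_state=True)(x)

# An LSTM with return_state=True returns [output, state_h, state_c]: TWO states.
lstm_out, lstm_h, lstm_c = LSTM(8, return_state=True)(x)

print(gru_plain.shape, gru_h.shape, lstm_h.shape, lstm_c.shape)
```

So any inference code ported from the LSTM example has to drop every reference to `state_c`, and the GRU must be created with `return_state=True` before its state can be unpacked at all.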
# encoder decoder model given by fchollet
from keras.layers import Input, GRU, Dense
from keras.models import Model
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = GRU(latent_dim, return_state=True)
encoder_outputs, state_h = encoder(encoder_inputs)
decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_gru = GRU(latent_dim, return_sequences=True)
decoder_outputs = decoder_gru(decoder_inputs, initial_state=state_h)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=["accuracy"])
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)
# inference model I tried making
encoder_model = Model(encoder_inputs, state_h)
decoder_state_input_h = Input(shape=(LATENT_DIMENSIONS,))
decoder_states_inputs = [decoder_state_input_h]
decoder_outputs, state_h, state_c = decoder_gru(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)
reverse_input_char_index = dict(
    (i, char) for char, i in source_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
The training part works fine, but the inference model raises the error
TypeError: 'Tensor' object is not iterable
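A likely cause: `decoder_gru` was created with `return_sequences=True` but without `return_state=True`, so calling it returns one tensor, and unpacking that symbolic tensor into `decoder_outputs, state_h, state_c` makes Python try to iterate over it. A GRU also has no `state_c`. Below is a minimal sketch of one way to fix it, adapted from the blog's LSTM inference code: the decoder GRU is created with `return_state=True` in the training model too, and the inference models unpack two values instead of three. The tiny vocabulary, sizes and `decode_sequence` helper are hypothetical stand-ins for illustration; the `"\t"`/`"\n"` start/stop characters follow the blog's character-level example. The sketch uses `latent_dim` everywhere, whereas the question mixes `latent_dim` and `LATENT_DIMENSIONS` (these must hold the same value).

```python
import numpy as np
from tensorflow.keras.layers import Input, GRU, Dense
from tensorflow.keras.models import Model

# Hypothetical sizes and vocabulary, for illustration only.
latent_dim = 16
target_token_index = {"\t": 0, "\n": 1, "a": 2, "b": 3}
num_encoder_tokens = 5
num_decoder_tokens = len(target_token_index)
max_decoder_seq_length = 10

# Training model: as in the question, except the decoder GRU also
# returns its state (return_state=True) so it can be reused at inference.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = GRU(latent_dim, return_state=True)
encoder_outputs, state_h = encoder(encoder_inputs)

decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_gru = GRU(latent_dim, return_sequences=True, return_state=True)
# The final decoder state is not needed during training, so discard it.
decoder_outputs, _ = decoder_gru(decoder_inputs, initial_state=[state_h])
decoder_dense = Dense(num_decoder_tokens, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Inference models reuse the same layers, hence the trained weights.
# A GRU has a single state tensor, so there is no state_c anywhere.
encoder_model = Model(encoder_inputs, state_h)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_outputs_inf, decoder_state_h = decoder_gru(
    decoder_inputs, initial_state=[decoder_state_input_h])
decoder_outputs_inf = decoder_dense(decoder_outputs_inf)
decoder_model = Model(
    [decoder_inputs, decoder_state_input_h],
    [decoder_outputs_inf, decoder_state_h])

reverse_target_char_index = {i: ch for ch, i in target_token_index.items()}


def decode_sequence(input_seq):
    """Hypothetical greedy decoder: one character per step."""
    # Encode the input sequence into the initial decoder state.
    h = encoder_model.predict(input_seq, verbose=0)
    # Start with the start-of-sequence character "\t".
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index["\t"]] = 1.0
    decoded = ""
    while True:
        # The decoder returns [probabilities, new state]; carry h forward.
        output_tokens, h = decoder_model.predict([target_seq, h], verbose=0)
        sampled_index = int(np.argmax(output_tokens[0, -1, :]))
        sampled_char = reverse_target_char_index[sampled_index]
        # Stop on the end-of-sequence character or at the length limit.
        if sampled_char == "\n" or len(decoded) >= max_decoder_seq_length:
            break
        decoded += sampled_char
        # Feed the sampled character back in as the next decoder input.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_index] = 1.0
    return decoded
```

Because this changes how `decoder_gru` is constructed, the training model has to be rebuilt with `return_state=True`. Since `return_state` does not add any weights, weights saved from the original model should still load into the rebuilt one.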
Thanks, Soumil.