I am training a sequence-to-sequence model in Keras. The model encodes with a bi-LSTM and decodes with two stacked bi-LSTM layers. During training this works like a charm.
When I load the trained model for inference, however, the output is pure garbage, and I would like to know what I am doing wrong. I have a version of the same model that uses only one bi-LSTM decoder layer (instead of two) during inference, and it works fine. I am fairly sure I understand the principle, and I have also studied the Keras seq2seq example; I must be missing something, but I cannot tell what. Does anyone see an obvious flaw in how I use/initialize the layers?
What confuses me is that inference works for the smaller model but not for the larger one. Calling predict() on the full training model clearly shows that the more complex model largely works and produces valid results.
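For context, the training model is wired up roughly like this (a simplified sketch reconstructed from the layers reused in the inference code below; latent_dim, the token counts, the go_backwards flags and the compile settings are placeholders and may not match my actual training script exactly):

from keras.models import Model
from keras.layers import Input, LSTM, Dense, TimeDistributed, Concatenate, Dot, Activation

latent_dim = 256            # placeholder hidden size
num_encoder_tokens = 100    # placeholder vocabulary sizes
num_decoder_tokens = 100

# encoder: one forward and one backward LSTM over the same input sequence
encoder_inputs = Input(shape=(None, num_encoder_tokens))   # input_1
decoder_inputs = Input(shape=(None, num_decoder_tokens))   # input_2
fwd_enc_out, fwd_enc_hid, fwd_enc_cell = LSTM(latent_dim, return_sequences=True,
                                              return_state=True)(encoder_inputs)
bwd_enc_out, bwd_enc_hid, bwd_enc_cell = LSTM(latent_dim, return_sequences=True,
                                              return_state=True, go_backwards=True)(encoder_inputs)
encoder_outputs = Concatenate()([fwd_enc_out, bwd_enc_out])

# first decoder layer pair, initialized with the encoder hidden/cell states
fwd_dec_out_1, _, _ = LSTM(latent_dim, return_sequences=True, return_state=True)(
    decoder_inputs, initial_state=[fwd_enc_hid, fwd_enc_cell])
bwd_dec_out_1, _, _ = LSTM(latent_dim, return_sequences=True, return_state=True,
                           go_backwards=True)(decoder_inputs,
                                              initial_state=[bwd_enc_hid, bwd_enc_cell])
dec_out_1 = Concatenate()([fwd_dec_out_1, bwd_dec_out_1])

# second decoder layer pair stacked on top of the first
fwd_dec_out_2, _, _ = LSTM(latent_dim, return_sequences=True, return_state=True)(
    dec_out_1, initial_state=[fwd_enc_hid, fwd_enc_cell])
bwd_dec_out_2, _, _ = LSTM(latent_dim, return_sequences=True, return_state=True,
                           go_backwards=True)(dec_out_1,
                                              initial_state=[bwd_enc_hid, bwd_enc_cell])
decoder_outputs = Concatenate()([fwd_dec_out_2, bwd_dec_out_2])

# dot-product attention over the encoder outputs, then a per-timestep softmax
attention = Activation('softmax')(Dot(axes=[2, 2])([decoder_outputs, encoder_outputs]))
context = Dot(axes=[2, 1])([attention, encoder_outputs])
combined = Concatenate()([context, decoder_outputs])
outputs = TimeDistributed(Dense(num_decoder_tokens, activation='softmax'))(combined)

model = Model([encoder_inputs, decoder_inputs], outputs)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')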
Here is the inference code that wires up the model. Does anyone see anything wrong or incorrect (specifically for the two decoder layers)?
# assumes `model` is the trained seq2seq model (e.g. restored with keras.models.load_model)
# and `latent_dim` matches the hidden size used during training
from keras.models import Model
from keras.layers import Input, Concatenate, Dot, Activation
import numpy as np

######################
### ENCODER MODEL ##
######################
encoder_inputs = model.input[0]  # input_1
fwd_encoder_lstm = model.layers[2]
fwd_enc_out, fwd_enc_hid, fwd_enc_cell = fwd_encoder_lstm(encoder_inputs)
bwd_encoder_lstm = model.layers[3]
bwd_enc_out, bwd_enc_hid, bwd_enc_cell = bwd_encoder_lstm(encoder_inputs)
encoding_results = Concatenate()([fwd_enc_out, bwd_enc_out])
encoder_model = Model(inputs=encoder_inputs,
                      outputs=[encoding_results, fwd_enc_hid, fwd_enc_cell, bwd_enc_hid, bwd_enc_cell])
######################
### DECODER MODEL ##
######################
decoder_inputs = model.input[1] # input_2
encoded_seq = Input(shape=(None, latent_dim*2), name='input__encoder_out')
in_fwd_hidden_1 = Input(shape=(latent_dim,), name='input__fwd_hidden_init_1')
in_fwd_cell_1 = Input(shape=(latent_dim,), name='input__fwd_cell_init_1')
in_bwd_hidden_1 = Input(shape=(latent_dim,), name='input__bwd_hidden_init_1')
in_bwd_cell_1 = Input(shape=(latent_dim,), name='input__bwd_cell_init_1')
fwd_init = [in_fwd_hidden_1, in_fwd_cell_1]
bwd_init = [in_bwd_hidden_1, in_bwd_cell_1]
# decoder lstm is initialized with encoder hidden/cell state
fwd_decoder_lstm = model.layers[4]
fwd_dec_out_1, fwd_dec_hid_1, fwd_dec_cell_1 = fwd_decoder_lstm(decoder_inputs,
                                                                initial_state=fwd_init)
bwd_decoder_lstm = model.layers[5]
bwd_dec_out_1, bwd_dec_hid_1, bwd_dec_cell_1 = bwd_decoder_lstm(decoder_inputs,
                                                                initial_state=bwd_init)
# outputs of the two first-layer decoder LSTMs are concatenated and fed to the second decoder layer
decoder_one_output = Concatenate()([fwd_dec_out_1, bwd_dec_out_1])
fwd_decoder_lstm = model.layers[7]
fwd_dec_out_2, fwd_dec_hid_2, fwd_dec_cell_2 = fwd_decoder_lstm(decoder_one_output,
                                                                initial_state=fwd_init)
bwd_decoder_lstm = model.layers[8]
bwd_dec_out_2, bwd_dec_hid_2, bwd_dec_cell_2 = bwd_decoder_lstm(decoder_one_output,
                                                                initial_state=bwd_init)
decoded_seq = Concatenate()([fwd_dec_out_2, bwd_dec_out_2])
# attention is computed from the encoder outputs and the decoder outputs
attention = Dot([2, 2])([decoded_seq, encoded_seq])
attention = Activation('softmax')(attention)
context = Dot([2,1])([attention, encoded_seq])
dec_combined = Concatenate()([context, decoded_seq])
timeDistributedLayer = model.layers[15]
final_prediction = timeDistributedLayer(dec_combined)
encoder_parameters = [encoded_seq, in_fwd_hidden_1, in_fwd_cell_1, in_bwd_hidden_1, in_bwd_cell_1]
decoder_parameters = [fwd_dec_hid_1, fwd_dec_cell_1, bwd_dec_hid_1, bwd_dec_cell_1]
decoder_model = Model(
    inputs=[decoder_inputs] + encoder_parameters,
    outputs=[final_prediction] + decoder_parameters)
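Just as a sanity check (not part of the wiring), I list the layer indices to make sure the hard-coded positions 2, 3, 4, 5, 7, 8 and 15 really point at the layers I expect:

# print index, name and type of every layer in the trained model
for idx, layer in enumerate(model.layers):
    print(idx, layer.name, type(layer).__name__)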
And here is the inference loop, where I feed in the states from the previous time step together with the start value for the next one:
def decode_sequence(input_seq):
    # BOS, EOS, MAX_LEN, target_token_index and reverse_target_char_index
    # come from the preprocessing step (not shown)
    # Encode the input as state vectors.
    encoded_seq, fwd_hid, fwd_cell, bwd_hid, bwd_cell = encoder_model.predict(input_seq)
    states_value = [encoded_seq, fwd_hid, fwd_cell, bwd_hid, bwd_cell]

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index[BOS]] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, next_fwd_hid, next_fwd_cell, next_bwd_hid, next_bwd_cell = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length or find stop character.
        if sampled_char == EOS or len(decoded_sentence) > MAX_LEN:
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [encoded_seq,
                        next_fwd_hid, next_fwd_cell,
                        next_bwd_hid, next_bwd_cell]

    # decoded sequence without EOS
    return decoded_sentence
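For completeness, this is roughly how decode_sequence is called; input_token_index, max_encoder_seq_length and num_encoder_tokens are the usual helpers from the Keras seq2seq example and may be named slightly differently in my preprocessing:

# one-hot encode a single test sequence and decode it
input_seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens))
for t, char in enumerate('test input'):
    input_seq[0, t, input_token_index[char]] = 1.
print(decode_sequence(input_seq))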