Keras sequence-to-sequence model breaks during inference

Asked: 2019-03-20 10:17:58

Tags: tensorflow keras sequence inference seq2seq

I am training a sequence-to-sequence model with Keras. The model encodes with one bi-LSTM and decodes with 2x bi-LSTM. During training this works like a charm.

When I load the trained model for inference, the output is pure garbage, and I wonder what I am doing wrong. I have a version of the same model that uses only one bidirectional decoder (instead of two), and it works fine during inference, so I am fairly sure I understand the principle. I have also looked at the Keras seq2seq example. I am missing something, but I cannot tell what. Does anyone see an obvious flaw in how I use/initialize the layers?

I am confused because inference works for the smaller model but not for the bigger one. Calling predict() during training clearly shows that the more complex model largely works and produces valid results.
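
(Roughly, the teacher-forced check I mean looks like the sketch below; encoder_input_data and decoder_input_data are the one-hot training arrays, and those names are assumptions here, not taken from my actual script:)

import numpy as np

# teacher-forced sanity check on the full training model (sketch)
preds = model.predict([encoder_input_data[:1], decoder_input_data[:1]])
pred_tokens = np.argmax(preds[0], axis=-1)   # most likely token index per time step
print([reverse_target_char_index[int(i)] for i in pred_tokens])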

Here is my inference code that wires up the models. Does anyone see anything wrong/incorrect (for the two decoder layers)?

from keras.models import Model
from keras.layers import Input, Concatenate, Dot, Activation
import numpy as np

######################
###  ENCODER MODEL  ##
######################
# "model" is the trained seq2seq model, loaded beforehand
encoder_inputs = model.input[0]   # input_1
fwd_encoder_lstm  = model.layers[2]
fwd_enc_out, fwd_enc_hid, fwd_enc_cell = fwd_encoder_lstm(encoder_inputs)
bwd_encoder_lstm  = model.layers[3]
bwd_enc_out, bwd_enc_hid, bwd_enc_cell = bwd_encoder_lstm(encoder_inputs)

encoding_results = Concatenate()([fwd_enc_out, bwd_enc_out])
encoder_model = Model(inputs=encoder_inputs, 
                      outputs=[encoding_results, fwd_enc_hid, fwd_enc_cell, bwd_enc_hid, bwd_enc_cell])

######################
###  DECODER MODEL  ##
######################
decoder_inputs = model.input[1]   # input_2
encoded_seq = Input(shape=(None, latent_dim*2), name='input__encoder_out')
in_fwd_hidden_1 = Input(shape=(latent_dim,), name='input__fwd_hidden_init_1')
in_fwd_cell_1 = Input(shape=(latent_dim,), name='input__fwd_cell_init_1')
in_bwd_hidden_1 = Input(shape=(latent_dim,), name='input__bwd_hidden_init_1')
in_bwd_cell_1 = Input(shape=(latent_dim,), name='input__bwd_cell_init_1')

fwd_init = [in_fwd_hidden_1, in_fwd_cell_1]
bwd_init = [in_bwd_hidden_1, in_bwd_cell_1]

# decoder lstm is initialized with encoder hidden/cell state
fwd_decoder_lstm = model.layers[4]
fwd_dec_out_1, fwd_dec_hid_1, fwd_dec_cell_1 = fwd_decoder_lstm(decoder_inputs, 
                                                                initial_state=fwd_init)
bwd_decoder_lstm = model.layers[5]
bwd_dec_out_1, bwd_dec_hid_1, bwd_dec_cell_1 = bwd_decoder_lstm(decoder_inputs, 
                                                                initial_state=bwd_init)

# the outputs of the first decoder layer are concatenated and fed into the second decoder layer
decoder_one_output = Concatenate()([fwd_dec_out_1, bwd_dec_out_1])

# second decoder bi-LSTM pair, stacked on top of the first one
fwd_decoder_lstm = model.layers[7]
fwd_dec_out_2, fwd_dec_hid_2, fwd_dec_cell_2 = fwd_decoder_lstm(decoder_one_output, 
                                                                initial_state=fwd_init)
bwd_decoder_lstm = model.layers[8]
bwd_dec_out_2, bwd_dec_hid_2, bwd_dec_cell_2 = bwd_decoder_lstm(decoder_one_output,
                                                                initial_state=bwd_init)

# attention is computed from the encoder outputs and the (second) decoder outputs
decoded_seq = Concatenate()([fwd_dec_out_2, bwd_dec_out_2])
attention = Dot([2,2])([decoded_seq, encoded_seq])
attention = Activation('softmax')(attention)
context = Dot([2,1])([attention, encoded_seq])
dec_combined = Concatenate()([context, decoded_seq])

timeDistributedLayer = model.layers[15]
final_prediction = timeDistributedLayer(dec_combined)

encoder_parameters = [encoded_seq, in_fwd_hidden_1, in_fwd_cell_1, in_bwd_hidden_1, in_bwd_cell_1]
decoder_parameters = [fwd_dec_hid_1, fwd_dec_cell_1, bwd_dec_hid_1, bwd_dec_cell_1]

decoder_model = Model(
    inputs=[decoder_inputs] + encoder_parameters,
    outputs=[final_prediction] + decoder_parameters)
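
(The layer indices used above, i.e. model.layers[2], [3], [4], [5], [7], [8] and [15], come from inspecting the loaded training model; a quick way to double-check them, in case the wiring picks up the wrong layers:)

# print index, name and type of every layer in the trained model
for i, layer in enumerate(model.layers):
    print(i, layer.name, type(layer).__name__)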

And here is the inference loop, where I feed the states from the previous time step back in as the starting values for the next one:

def decode_sequence(input_seq):
    # Encode the input as state vectors.
    encoded_seq, fwd_hid, fwd_cell, bwd_hid, bwd_cell = encoder_model.predict(input_seq)
    states_value = [encoded_seq, fwd_hid, fwd_cell, bwd_hid, bwd_cell]

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index[BOS]] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, next_fwd_hid, next_fwd_cell, next_bwd_hid, next_bwd_cell = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == EOS or len(decoded_sentence) > MAX_LEN):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [encoded_seq, 
                       next_fwd_hid, next_fwd_cell, 
                       next_bwd_hid, next_bwd_cell]

    # decoded sequence without EOS
    return decoded_sentence
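
(For completeness, this is roughly how I drive the loop; encoder_input_data and input_texts come from the usual one-hot preprocessing as in the Keras seq2seq example and are assumed names here:)

# decode a few samples and compare them with the source text (sketch)
for seq_index in range(10):
    input_seq = encoder_input_data[seq_index: seq_index + 1]   # shape (1, timesteps, num_encoder_tokens)
    print('Input sentence:  ', input_texts[seq_index])
    print('Decoded sentence:', decode_sequence(input_seq))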

0 Answers:

No answers yet