Tensorflow: troubleshooting .clone() in a seq2seq model with Attention and BeamSearch

Time: 2018-03-27 14:07:32

Tags: python-3.x tensorflow beam-search seq2seq

I am trying to implement a seq2seq model using bidirectional_dynamic_rnn, Attention and the BeamSearchDecoder in TensorFlow (1.6.0). (I have tried to copy only the relevant code to keep things simple.)

import tensorflow as tf
from tensorflow.python.layers.core import Dense

# encoder: stacked LSTM cells with dropout applied to the inputs
def make_lstm(rnn_size, keep_prob):
    lstm = tf.nn.rnn_cell.LSTMCell(
        rnn_size,
        initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))
    lstm_dropout = tf.nn.rnn_cell.DropoutWrapper(
        lstm, input_keep_prob=keep_prob)
    return lstm_dropout

cell_fw = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
cell_bw = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])

# enc_state is a (forward_state, backward_state) pair,
# one MultiRNNCell state per direction
enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, rnn_inputs,
    sequence_length=sequence_length, dtype=tf.float32)
enc_output = tf.concat(enc_output, 2)

# decoder cell and output projection layer
dec_cell = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
output_layer = Dense(
    vocab_size,
    kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

# training_decoding_layer
with tf.variable_scope('decode'):
    ...

# inference_decoding_layer
with tf.variable_scope('decode', reuse=True):
    beam_width = 10

    # tile the encoder outputs, final state and sequence lengths
    # across the beam dimension
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
        enc_output, multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
        enc_state, multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
        text_length, multiplier=beam_width)

    start_tokens = tf.tile(
        tf.constant([word2ind['<GO>']], dtype=tf.int32),
        [batch_size], name='start_tokens')

    attn_mech = tf.contrib.seq2seq.BahdanauAttention(
        num_units=rnn_size,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length,
        normalize=True)

    beam_dec_cell = tf.contrib.seq2seq.AttentionWrapper(
        dec_cell, attn_mech, rnn_size)

    beam_initial_state = beam_dec_cell.zero_state(
        batch_size=batch_size * beam_width, dtype=tf.float32)
    beam_initial_state = beam_initial_state.clone(
        cell_state=tiled_encoder_final_state)

However, when I try to clone the encoder's final state onto the 'beam_initial_state' variable above, I get the following error:

ValueError: The two structures don't have the same number of elements.

First structure (6 elements): AttentionWrapperState(
    cell_state=(LSTMStateTuple(
        c=<tf.Tensor 'decode_1/AttentionWrapperZeroState/checked_cell_state:0' shape=(640, 256) dtype=float32>,
        h=<tf.Tensor 'decode_1/AttentionWrapperZeroState/checked_cell_state_1:0' shape=(640, 256) dtype=float32>),),
    attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256) dtype=float32>,
    time=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
    alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0' shape=(640, ?) dtype=float32>,
    alignment_history=(),
    attention_state=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640, ?) dtype=float32>)

Second structure (8 elements): AttentionWrapperState(
    cell_state=((LSTMStateTuple(
        c=<tf.Tensor 'decode_1/tile_batch_1/Reshape:0' shape=(?, 256) dtype=float32>,
        h=<tf.Tensor 'decode_1/tile_batch_1/Reshape_1:0' shape=(?, 256) dtype=float32>),),
    (LSTMStateTuple(
        c=<tf.Tensor 'decode_1/tile_batch_1/Reshape_2:0' shape=(?, 256) dtype=float32>,
        h=<tf.Tensor 'decode_1/tile_batch_1/Reshape_3:0' shape=(?, 256) dtype=float32>),)),
    attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256) dtype=float32>,
    time=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
    alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0' shape=(640, ?) dtype=float32>,
    alignment_history=(),
    attention_state=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640, ?) dtype=float32>)

Does anyone have any suggestions? Many thanks.

1 Answer:

Answer 0 (score: 0)

The mismatch arises because tf.nn.bidirectional_dynamic_rnn returns its final state as a (forward_state, backward_state) pair: the second structure in the error contains two LSTMStateTuples, while the decoder cell's zero state contains only one. You need to manually add (or concatenate) the forward and backward states for each MultiRNNCell:

def add_stacked_cell_state(forward_state, backward_state, useGRUCell):
    # merge the forward and backward states layer by layer
    temp_list = []
    for state_fw, state_bw in zip(forward_state, backward_state):
        if useGRUCell:
            # a GRU state is a single tensor per layer
            temp_list.append(tf.add(state_fw, state_bw))
        else:
            # an LSTM state is a (c, h) pair per layer
            temp_list2 = []
            for hidden_fw, hidden_bw in zip(state_fw, state_bw):
                temp_list2.append(tf.add(hidden_fw, hidden_bw))

            LSTMtuple = tf.contrib.rnn.LSTMStateTuple(*temp_list2)
            temp_list.append(LSTMtuple)

    stacked_state = tuple(temp_list)
    return stacked_state
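
If you would rather concatenate than add (mirroring the tf.concat(enc_output, 2) used for the outputs in the question), a minimal sketch of the LSTM case could look like the following. concat_stacked_cell_state is a hypothetical name, and note that the merged state becomes 2 * rnn_size wide, so dec_cell would then have to be built with 2 * rnn_size units:

# hypothetical concat variant: the merged state doubles in width,
# so the decoder cells must use 2 * rnn_size units to accept it
def concat_stacked_cell_state(forward_state, backward_state):
    merged = []
    for state_fw, state_bw in zip(forward_state, backward_state):
        merged.append(tf.contrib.rnn.LSTMStateTuple(
            c=tf.concat([state_fw.c, state_bw.c], axis=1),
            h=tf.concat([state_fw.h, state_bw.h], axis=1)))
    return tuple(merged)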

Then apply add_stacked_cell_state:

enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, rnn_inputs,
    sequence_length=sequence_length, dtype=tf.float32)

# collapse the (forward, backward) state pair into a single
# MultiRNNCell-style state that matches dec_cell's structure
enc_state = add_stacked_cell_state(enc_state[0], enc_state[1], useGRUCell=False)

# sum the forward and backward outputs so the attention memory
# width stays rnn_size, consistent with the added states
enc_output = tf.add(enc_output[0], enc_output[1])
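
With the forward and backward states merged, the tiled encoder state has the same structure as beam_dec_cell's zero state, so the clone() call from the question should now succeed. Below is a minimal sketch of the remaining inference wiring under that assumption, reusing beam_dec_cell, start_tokens, output_layer and beam_width from the question; the embeddings matrix and the '<EOS>' token id are assumptions, since they do not appear in the original code:

tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
    enc_state, multiplier=beam_width)

beam_initial_state = beam_dec_cell.zero_state(
    batch_size=batch_size * beam_width, dtype=tf.float32)
# both structures now contain one LSTMStateTuple per decoder layer,
# so clone() no longer raises the ValueError
beam_initial_state = beam_initial_state.clone(
    cell_state=tiled_encoder_final_state)

inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=beam_dec_cell,
    embedding=embeddings,            # assumed embedding matrix
    start_tokens=start_tokens,
    end_token=word2ind['<EOS>'],     # assumed end-of-sequence id
    initial_state=beam_initial_state,
    beam_width=beam_width,
    output_layer=output_layer)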