Tensorflow: Trouble with BeamSearchDecoder - dynamic_decode

Posted: 2018-03-28 17:24:51

Tags: python tensorflow rnn seq2seq

I am implementing a seq2seq model with bidirectional multi-layer LSTMs, attention, and beam search. (Only posting the necessary code to keep it simple.)

# helper to create the layers
def make_lstm(rnn_size, keep_prob):
    lstm = tf.nn.rnn_cell.LSTMCell(rnn_size, initializer = tf.random_uniform_initializer(-0.1, 0.1, seed=2))
    lstm_dropout = tf.nn.rnn_cell.DropoutWrapper(lstm, input_keep_prob = keep_prob)
    return lstm_dropout

# helper to wrap the decoder cell with Bahdanau attention
def decoder_cell(dec_cell, rnn_size, enc_output, lengths):
    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            num_units              = rnn_size,
            memory                 = enc_output,
            memory_sequence_length = lengths,
            normalize              = True,
            name                   = 'BahdanauAttention')

    return tf.contrib.seq2seq.AttentionWrapper(
            cell                 = dec_cell,
            attention_mechanism  = attention_mechanism,
            attention_layer_size = rnn_size)

Encoder

# forward
cell_fw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])

# backward
cell_bw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])

enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                        cell_bw,
                                                        rnn_inputs,
                                                        sequence_length=sequence_length,
                                                        dtype=tf.float32)

# concatenate forward and backward outputs -> [batch_size, time, 2*rnn_size]
enc_output = tf.concat(enc_output, -1)

Decoder

beam_width = 10

# output_layer uses Dense from tensorflow.python.layers.core (TF 1.x)
dec_cell = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
output_layer = Dense(vocab_size, kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))
dec_cell = decoder_cell(dec_cell, rnn_size, enc_output, text_length)


with tf.variable_scope("decode"):
    # (dec_embed_input comes from another function but should not be 
    #   relevant in this context. )
    helper = tf.contrib.seq2seq.TrainingHelper(inputs = dec_embed_input, 
                                               sequence_length = summary_length,
                                               time_major = False)

    decoder = tf.contrib.seq2seq.BasicDecoder(cell = dec_cell,
                                              helper = helper,
                                              initial_state = dec_cell.zero_state(batch_size, tf.float32),
                                              output_layer = output_layer)

    # dynamic_decode returns (final_outputs, final_state, final_sequence_lengths)
    logits, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,
                                                     output_time_major=False,
                                                     impute_finished=True,
                                                     maximum_iterations=max_summary_length)

# tile the encoder tensors across the beam so attention sees batch_size*beam_width entries
enc_output = tf.contrib.seq2seq.tile_batch(enc_output, multiplier=beam_width)
enc_state = tf.contrib.seq2seq.tile_batch(enc_state, multiplier=beam_width)
text_length = tf.contrib.seq2seq.tile_batch(text_length, multiplier=beam_width)

dec_cell = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
dec_cell = decoder_cell(dec_cell, rnn_size, enc_output, text_length)

start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype = tf.int32), [batch_size], name = 'start_tokens')

with tf.variable_scope("decode", reuse = True):


    decoder = tf.contrib.seq2seq.BeamSearchDecoder( cell=dec_cell,
                                                    embedding=embeddings,
                                                    start_tokens=start_tokens,
                                                    end_token=end_token,
                                                    initial_state=dec_cell.zero_state(batch_size = batch_size*beam_width, dtype = tf.float32),
                                                    beam_width=beam_width,
                                                    output_layer=output_layer,
                                                    length_penalty_weight=0.0)

    # dynamic_decode returns (final_outputs, final_state, final_sequence_lengths)
    logits, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,
                                                     output_time_major=False,
                                                     impute_finished=True,
                                                     maximum_iterations=max_summary_length)

At this line:

decoder = tf.contrib.seq2seq.BeamSearchDecoder( cell=dec_cell,
                                                embedding=embeddings,
                                                start_tokens=start_tokens,
                                                end_token=end_token,
                                                initial_state=dec_cell.zero_state(batch_size = batch_size*beam_width, dtype = tf.float32),
                                                beam_width=beam_width,
                                                output_layer=output_layer,
                                                length_penalty_weight=0.0)

I get the following error:

ValueError: Shapes must be equal rank, but are 3 and 2 for 'decode_1/decoder/while/Select_4' (op: 'Select') with input shapes: [64,10], [64,10,256], [64,10,256].

Has anyone had experience with this, or run into the same problem? I would really appreciate any advice.

Tensorflow: 1.6.0, batch_size = 64, rnn_size = 256

2 answers:

Answer 0 (score: 1):

Make sure to pass impute_finished=False to dynamic_decode().
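
With BeamSearchDecoder the finished mask is rank 2 ([batch_size, beam_width]) while the cell outputs are rank 3 ([batch_size, beam_width, rnn_size]); that is exactly the [64,10] vs. [64,10,256] Select mismatch in the error, so imputing finished outputs cannot work. A minimal sketch of the corrected beam-search decode, reusing the question's variables (beam_output and predictions are illustrative names):

beam_output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,
                                                      output_time_major=False,
                                                      # must be False for beam search
                                                      impute_finished=False,
                                                      maximum_iterations=max_summary_length)

# predicted token ids for each beam: [batch_size, time, beam_width]
predictions = beam_output.predicted_ids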

Answer 1 (score: 0):

I think you need to set the decoder's initial_state = encoder_state:

enc_output = tf.contrib.seq2seq.tile_batch(enc_output, multiplier=beam_width)

# interleave the forward/backward states of each bidirectional layer
num_bi_layers = int(num_layers/2)
if num_bi_layers == 1:
    encoder_state = enc_state
else:
    encoder_state = []
    for layer_id in range(num_bi_layers):
        encoder_state.append(enc_state[0][layer_id])  # forward
        encoder_state.append(enc_state[1][layer_id])  # backward
    encoder_state = tuple(encoder_state)

encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)

text_length = tf.contrib.seq2seq.tile_batch(text_length, multiplier=beam_width)

dec_cell = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
dec_cell = decoder_cell(dec_cell, rnn_size, enc_output, text_length)

start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype = tf.int32), [batch_size], name = 'start_tokens')

with tf.variable_scope("decode", reuse = True):
    decoder = tf.contrib.seq2seq.BeamSearchDecoder( cell=dec_cell,
                                                    embedding=embeddings,
                                                    start_tokens=start_tokens,
                                                    end_token=end_token,
                                                    initial_state=encoder_state,
                                                    beam_width=beam_width,
                                                    output_layer=output_layer,
                                                    length_penalty_weight=0.0)
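
Note that dec_cell here is an AttentionWrapper, so the tiled encoder state cannot be fed to the decoder as-is. A common pattern (a sketch, assuming the variables above) is to clone the wrapper's zero state with the encoder state as its cell_state, and pass the result as initial_state to BeamSearchDecoder:

# wrap the tiled encoder state in the AttentionWrapper's state structure
initial_state = dec_cell.zero_state(batch_size = batch_size*beam_width,
                                    dtype = tf.float32).clone(cell_state=encoder_state)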