I am trying to implement a seq2seq model with bidirectional_dynamic_rnn, Attention, and the BeamSearchDecoder in TensorFlow (1.6.0). (I have only copied the relevant code here to keep it simple.)
# encoder
def make_lstm(rnn_size, keep_prob):
    lstm = tf.nn.rnn_cell.LSTMCell(rnn_size,
                                   initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))
    lstm_dropout = tf.nn.rnn_cell.DropoutWrapper(lstm, input_keep_prob=keep_prob)
    return lstm_dropout

cell_fw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                       for _ in range(n_layers)])
cell_bw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                       for _ in range(n_layers)])

enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                        cell_bw,
                                                        rnn_inputs,
                                                        sequence_length=sequence_length,
                                                        dtype=tf.float32)
enc_output = tf.concat(enc_output, 2)
dec_cell = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                        for _ in range(num_layers)])

output_layer = Dense(vocab_size,
                     kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
# training_decoding_layer
with tf.variable_scope('decode'):
    ....
# inference_decoding_layer
with tf.variable_scope('decode', reuse=True):
    beam_width = 10
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(enc_output,
                                                          multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(enc_state,
                                                              multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(text_length,
                                                          multiplier=beam_width)

    start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype=tf.int32),
                           [batch_size], name='start_tokens')

    attn_mech = tf.contrib.seq2seq.BahdanauAttention(num_units=rnn_size,
                                                     memory=tiled_encoder_outputs,
                                                     memory_sequence_length=tiled_sequence_length,
                                                     normalize=True)

    beam_dec_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell, attn_mech, rnn_size)

    beam_initial_state = beam_dec_cell.zero_state(batch_size=batch_size * beam_width,
                                                  dtype=tf.float32)
    beam_initial_state = beam_initial_state.clone(cell_state=tiled_encoder_final_state)
However, when I try to clone the encoder's final state into the 'beam_initial_state' variable above, I get the following error:
ValueError: The two structures don't have the same number of elements.

First structure (6 elements): AttentionWrapperState(cell_state=
(LSTMStateTuple(c=<tf.Tensor 'decode_1/AttentionWrapperZeroState/checked_cell_state:0' shape=(640, 256) dtype=float32>,
h=<tf.Tensor 'decode_1/AttentionWrapperZeroState/checked_cell_state_1:0' shape=(640, 256) dtype=float32>),),
attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256) dtype=float32>,
time=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0' shape=(640, ?) dtype=float32>,
alignment_history=(),
attention_state=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640, ?) dtype=float32>)

Second structure (8 elements): AttentionWrapperState(cell_state=
((LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape:0' shape=(?, 256) dtype=float32>,
h=<tf.Tensor 'decode_1/tile_batch_1/Reshape_1:0' shape=(?, 256) dtype=float32>),),
(LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape_2:0' shape=(?, 256) dtype=float32>,
h=<tf.Tensor 'decode_1/tile_batch_1/Reshape_3:0' shape=(?, 256) dtype=float32>),)),
attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256) dtype=float32>,
time=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0' shape=(640, ?) dtype=float32>,
alignment_history=(),
attention_state=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640, ?) dtype=float32>)
Does anyone have any suggestions? Many thanks.
Answer 0 (score: 0)
You need to manually add (or concatenate) the forward and backward states for each layer of the MultiRNNCell. Adding keeps each state at size rnn_size, so the combined state has the same structure as a single uni-directional MultiRNNCell state:
def add_stacked_cell_state(forward_state, backward_state, useGRUCell):
    # Combine the forward and backward states layer by layer so the result
    # matches the state structure of a single MultiRNNCell.
    temp_list = []
    for state_fw, state_bw in zip(forward_state, backward_state):
        if useGRUCell:
            # A GRU layer's state is a single tensor.
            temp_list.append(tf.add(state_fw, state_bw))
        else:
            # An LSTM layer's state is a (c, h) LSTMStateTuple.
            temp_list2 = []
            for hidden_fw, hidden_bw in zip(state_fw, state_bw):
                temp_list2.append(tf.add(hidden_fw, hidden_bw))
            LSTMtuple = tf.contrib.rnn.LSTMStateTuple(*temp_list2)
            temp_list.append(LSTMtuple)
    stacked_state = tuple(temp_list)
    return stacked_state
Then apply it:
enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                        cell_bw,
                                                        rnn_inputs,
                                                        sequence_length=sequence_length,
                                                        dtype=tf.float32)
enc_state = add_stacked_cell_state(enc_state[0], enc_state[1], useGRUCell=False)
enc_output = tf.add(enc_output[0], enc_output[1])
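After this combination, enc_state has the same nested structure as the single dec_cell MultiRNNCell, so the tile_batch / zero_state / clone sequence from the question should line up. Below is a minimal sketch of that wiring, reusing the variable names from the question; it is not a tested drop-in for your full graph:

# Assumes enc_state has already been combined with add_stacked_cell_state above.
# Its structure now matches dec_cell, so cloning the AttentionWrapper zero state
# should no longer raise the "don't have the same number of elements" ValueError.
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(enc_state,
                                                          multiplier=beam_width)

beam_dec_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell, attn_mech, rnn_size)
beam_initial_state = beam_dec_cell.zero_state(batch_size=batch_size * beam_width,
                                              dtype=tf.float32)
beam_initial_state = beam_initial_state.clone(cell_state=tiled_encoder_final_state)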