Tensorflow BeamSearchDecoder如何处理initial_state

时间:2018-05-03 21:34:52

标签: tensorflow

我编写基于tensorflow seq2seq API的代码来构建一个波束搜索解码器,并在下面的代码行结束时出现错误

decoder_initial_state = decoder_initial_state.clone(cell_state=encoded_states)    

其中encoded_states和decoder_initial_state由

生成
encoded_tupple, encoded_states = bidirectional_dynamic_rnn(lstm_cell_fw, lstm_cell_bw, inputs, dtype=tf.float32, sequence_length=input_sequence_lengths)
encoded_states = tile_batch(encoded_states, multiplier=self.beam_width) 

decoder_initial_state = rnn_cell.zero_state(batch_size, tf.float32)

我得到的错误是

ValueError:这两个结构没有相同数量的元素。

第一个结构(6个元素):AttentionWrapperState(cell_state =(LSTMStateTuple(c =,h =),),attention =,time =,alignments =,alignment_history =(),attention_state =)

第二个结构(8个元素):AttentionWrapperState(cell_state =(LSTMStateTuple(c =,h =),LSTMStateTuple(c =,h =)),attention =,time =,alignments =,alignment_history =(),attention_state = )

我想这是因为我在传递给clone()方法之前没有正确打包encoded_states。

我应该怎么做,谢谢!

0 个答案:

没有答案