我编写基于tensorflow seq2seq API的代码来构建一个波束搜索解码器,并在下面的代码行结束时出现错误
decoder_initial_state = decoder_initial_state.clone(cell_state=encoded_states)
其中encoded_states和decoder_initial_state由
生成encoded_tupple, encoded_states = bidirectional_dynamic_rnn(lstm_cell_fw, lstm_cell_bw, inputs, dtype=tf.float32, sequence_length=input_sequence_lengths)
encoded_states = tile_batch(encoded_states, multiplier=self.beam_width)
decoder_initial_state = rnn_cell.zero_state(batch_size, tf.float32)
我得到的错误是
ValueError:这两个结构没有相同数量的元素。
第一个结构(6个元素):AttentionWrapperState(cell_state =(LSTMStateTuple(c =,h =),),attention =,time =,alignments =,alignment_history =(),attention_state =)
第二个结构(8个元素):AttentionWrapperState(cell_state =(LSTMStateTuple(c =,h =),LSTMStateTuple(c =,h =)),attention =,time =,alignments =,alignment_history =(),attention_state = )
我想这是因为我在传递给clone()方法之前没有正确打包encoded_states。
我应该怎么做,谢谢!