How do I initialize an LSTMStateTuple in TensorFlow?

Time: 2018-06-07 13:36:10

Tags: python tensorflow

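I am new to TensorFlow and am trying to generate sentences with an RNN language model. To implement beam search, I need to initialize the LSTM state at each step with the states of the hypotheses that produced the highest-probability sentences. How can I do this? Any help would be appreciated. Here is my code: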

for i in range(max_gen_len):
    feed_dict = {model.input_data: x_batch,
                 model._initial_state: state}
    fetches = {"probs": model.probs,
               "log_probs": model.log_probs,
               "state": model.final_state}
    vals = sess.run(fetches, feed_dict)
    log_p = np.reshape(vals['log_probs'], [vals['log_probs'].shape[0], vals['log_probs'].shape[2]])
    state = vals['state'][-1]

    cand_scores = hyp_scores[:, None] + log_p
    vocab_size = log_p.shape[1]
    cand_flat = cand_scores.flatten()

    this_count = FLAGS.beam_size - dead_k
    ranks_flat = cand_flat.argsort()[(-this_count):]
    trans_idx = ranks_flat / vocab_size
    word_idx = ranks_flat % vocab_size

    state_c, state_h = state[0], state[1]
    new_hyp_samples = []
    new_hyp_scores = np.zeros(this_count).astype('float32')

    new_hyp_states_c = []
    new_hyp_states_h = []
    for idx, [ti, wi] in enumerate(zip(trans_idx, word_idx)):
        new_hyp_samples.append(hyp_samples[ti] + [wi])
        new_hyp_scores[idx] = copy.deepcopy(cand_scores[ti][wi])
        new_hyp_states_c.append(state_c[ti])
        new_hyp_states_h.append(state_h[ti])
    tf.contrib.rnn.LSTMStateTuple(np.array(new_hyp_states_c),
                                  np.array(new_hyp_states_h))

    new_live_k = 0

    hyp_samples = []
    hyp_scores = []
    hyp_states_c = [[[]]]
    hyp_states_h = [[[]]]

    state = model._initial_state

    for idx in range(len(new_hyp_samples)):
        if vocab.id2char(new_hyp_samples[idx][-1]) == '</s>':
            sample.append(new_hyp_samples[idx])
            sample_score.append(new_hyp_scores[idx])
            dead_k += 1
        else:
            new_live_k += 1
            hyp_samples.append(new_hyp_samples[idx])

            hyp_states_c[0][0].append(new_hyp_states_c[idx])
            hyp_states_h[0][0].append(new_hyp_states_h[idx])
            hyp_scores.append(new_hyp_scores[idx])
    state = tf.contrib.rnn.LSTMStateTuple(hyp_states_c, hyp_states_h)
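
The per-hypothesis state bookkeeping in the loop above can also be written with NumPy indexing. What follows is a minimal sketch, not the code from the question: `select_beam_states` is a hypothetical helper, and it assumes a single LSTM layer whose `c` and `h` arrays have shape `(beam, hidden)`, which is what the shapes in the question suggest.

```python
import numpy as np


def select_beam_states(state_c, state_h, cand_scores, k):
    """Pick the k best (parent, word) pairs and gather the parents' state rows.

    state_c, state_h: arrays of shape (beam, hidden)
    cand_scores:      array of shape (beam, vocab) of cumulative log-probs
    """
    vocab_size = cand_scores.shape[1]
    flat = cand_scores.ravel()
    # Indices of the k highest scores, in ascending score order.
    ranks_flat = np.argsort(flat)[-k:]
    # Row of the flat index = parent hypothesis, column = chosen word.
    trans_idx = ranks_flat // vocab_size
    word_idx = ranks_flat % vocab_size
    # Integer-array indexing copies each parent's state row once per child.
    new_c = state_c[trans_idx]
    new_h = state_h[trans_idx]
    return trans_idx, word_idx, flat[ranks_flat], new_c, new_h
```

The sketch uses `//` rather than `/` so that the indices stay integers on Python 3 as well. The returned `new_c` and `new_h` are already plain `(k, hidden)` arrays, so no nested lists are involved.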

When I feed this state back in as the initial state, I get the following error:

Traceback (most recent call last):
  File "/data00/home/zhaodongdi/workspace/lab-speech/asr_rnnlm/scripts/../src/rnnlm.py", line 579, in <module>
    tf.app.run()
  File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "/data00/home/zhaodongdi/workspace/lab-speech/asr_rnnlm/scripts/../src/rnnlm.py", line 573, in main
    gen(config)
  File "/data00/home/zhaodongdi/workspace/lab-speech/asr_rnnlm/scripts/../src/rnnlm.py", line 539, in gen
    gen_sum=1000)
  File "/data00/home/zhaodongdi/workspace/lab-speech/asr_rnnlm/scripts/../src/rnnlm.py", line 189, in batch_generation
    vals = sess.run(fetches, feed_dict)
  File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1070, in _run
    feed_dict = nest.flatten_dict_items(feed_dict)
  File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-packages/tensorflow/python/util/nest.py", line 232, in flatten_dict_items
    % (len(flat_i), len(flat_v), flat_i, flat_v))
ValueError: Could not flatten dictionary. Key had 2 elements, but value had 20 elements. Key: [<tf.Tensor 'Generate/Model/model/MultiRNNCellZeroState/BasicLSTMCellZeroState/zeros:0' shape=(10, 640) dtype=float32>, <tf.Tensor 'Generate/Model/model/MultiRNNCellZeroState/BasicLSTMCellZeroState/zeros_1:0' shape=(10, 640) dtype=float32>], value: 
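
The message says that the feed key flattens to 2 tensors (the `c` and `h` of one layer, each of shape `(10, 640)`), while the fed value flattens to 20 leaves. One plausible reading, which is an assumption because the printed value is cut off, is that nested Python lists of per-hypothesis vectors are being expanded into 10 separate arrays each for `c` and `h`. The sketch below reproduces that count without TensorFlow: `StateTuple` is a stand-in namedtuple for `LSTMStateTuple`, and `flatten` is a simplified, hypothetical imitation of structure flattening in which lists and tuples are expanded and NumPy arrays count as leaves.

```python
import collections

import numpy as np

# Stand-in for tf.contrib.rnn.LSTMStateTuple (assumption: same c/h fields).
StateTuple = collections.namedtuple("StateTuple", ["c", "h"])


def flatten(structure):
    """Expand nested lists/tuples into a flat list; arrays are leaves."""
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


beam, hidden = 10, 640
rows_c = [np.zeros(hidden, dtype=np.float32) for _ in range(beam)]
rows_h = [np.zeros(hidden, dtype=np.float32) for _ in range(beam)]

# Nested lists of per-hypothesis vectors: every vector becomes its own leaf.
nested = StateTuple([[rows_c]], [[rows_h]])
n_nested = len(flatten(nested))      # 20 leaves

# One (beam, hidden) array per field, wrapped in a one-layer tuple.
stacked = (StateTuple(np.stack(rows_c), np.stack(rows_h)),)
n_stacked = len(flatten(stacked))    # 2 leaves
```

If this reading is right, a value with the same nesting as `model._initial_state`, where each leaf is a single `(batch, hidden)` array, would pass the element-count check. The key tensors in the message have a fixed shape of `(10, 640)`, so a state with fewer live hypotheses would probably also need padding to 10 rows.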

0 answers:

No answers