我是Tensorflow的新手,现在尝试使用Tensorflow中的RNN语言模型生成句子。然而,当我尝试实现波束搜索算法时,LSTM状态需要由生成高概率句子的状态初始化。我怎样才能做到这一点?任何帮助,将不胜感激。 这是我的代码:
for i in range(max_gen_len):
feed_dict = {model.input_data: x_batch,
model._initial_state: state}
fetches = {"probs": model.probs,
"log_probs": model.log_probs,
"state": model.final_state}
vals = sess.run(fetches, feed_dict)
log_p = np.reshape(vals['log_probs'], [vals['log_probs'].shape[0], vals['log_probs'].shape[2]])
state = vals['state'][-1]
cand_scores = hyp_scores[:, None] + log_p
vocab_size = log_p.shape[1]
cand_flat = cand_scores.flatten()
this_count = FLAGS.beam_size - dead_k
ranks_flat = cand_flat.argsort()[(-this_count):]
trans_idx = ranks_flat / vocab_size
word_idx = ranks_flat % vocab_size
state_c, state_h = state[0], state[1]
new_hyp_samples = []
new_hyp_scores = np.zeros(this_count).astype('float32')
new_hyp_states_c = []
new_hyp_states_h = []
for idx, [ti, wi] in enumerate(zip(trans_idx, word_idx)):
new_hyp_samples.append(hyp_samples[ti] + [wi])
new_hyp_scores[idx] = copy.deepcopy(cand_scores[ti][wi])
new_hyp_states_c.append(state_c[ti])
new_hyp_states_h.append(state_h[ti])
tf.contrib.rnn.LSTMStateTuple(np.array(new_hyp_states_c),
np.array(new_hyp_states_h))
new_live_k = 0
hyp_samples = []
hyp_scores = []
hyp_states_c = [[[]]]
hyp_states_h = [[[]]]
state = model._initial_state
for idx in range(len(new_hyp_samples)):
if vocab.id2char(new_hyp_samples[idx][-1]) == '</s>':
sample.append(new_hyp_samples[idx])
sample_score.append(new_hyp_scores[idx])
dead_k += 1
else:
new_live_k += 1
hyp_samples.append(new_hyp_samples[idx])
hyp_states_c[0][0].append(new_hyp_states_c[idx])
hyp_states_h[0][0].append(new_hyp_states_h[idx])
hyp_scores.append(new_hyp_scores[idx])
state = tf.contrib.rnn.LSTMStateTuple(hyp_states_c, hyp_states_h)
当反馈初始化状态时,得到了这样的错误:
Traceback (most recent call last):
File "/data00/home/zhaodongdi/workspace/lab-
speech/asr_rnnlm/scripts/../src/rnnlm.py", line 579, in <module>
tf.app.run()
File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-
packages/tensorflow/python/platform/app.py", line 126, in run
_sys.exit(main(argv))
File "/data00/home/zhaodongdi/workspace/lab-
speech/asr_rnnlm/scripts/../src/rnnlm.py", line 573, in main
gen(config)
File "/data00/home/zhaodongdi/workspace/lab-
speech/asr_rnnlm/scripts/../src/rnnlm.py", line 539, in gen
gen_sum=1000)
File "/data00/home/zhaodongdi/workspace/lab-speech/asr_rnnlm/scripts/../src/rnnlm.py", line 189, in
batch_generation
vals = sess.run(fetches, feed_dict)
File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1070, in _run
feed_dict = nest.flatten_dict_items(feed_dict)
File "/data00/home/wupeihao/anaconda2/lib/python2.7/site-
packages/tensorflow/python/util/nest.py", line 232, in
flatten_dict_items
% (len(flat_i), len(flat_v), flat_i, flat_v))
ValueError: Could not flatten dictionary. Key had 2 elements, but
value had 20 elements. Key: [<tf.Tensor 'Generate/Model/model/MultiRNNCellZeroState/BasicLSTMCellZeroState/zeros:0' shape=(10, 640) dtype=float32>, <tf.Tensor 'Generate/Model/model/MultiRNNCellZeroState/BasicLSTMCellZeroState/zeros_1:0' shape=(10, 640) dtype=float32>], value: