Training goes fine, but when I try to run "decode" mode on the toy data provided by the model's authors, I get this error:
Traceback (most recent call last):
File "/home/pavel/Sandbox/TensorFlow/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 212, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 30, in run
sys.exit(main(sys.argv[:1] + flags_passthrough))
File "/home/pavel/Sandbox/TensorFlow/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 208, in main
decoder.DecodeLoop()
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/seq2seq_attention_decode.py", line 101, in DecodeLoop
if not self._Decode(self._saver, sess):
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/seq2seq_attention_decode.py", line 140, in _Decode
best_beam = bs.BeamSearch(sess, article_batch_cp, article_lens_cp)[0]
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/beam_search.py", line 113, in BeamSearch
sess, latest_tokens, enc_top_states, states)
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/seq2seq_attention_model.py", line 283, in decode_topk
feed_dict=feed)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 717, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 864, in _run
feed_dict = nest.flatten_dict_items(feed_dict)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/nest.py", line 186, in flatten_dict_items
% (len(flat_i), len(flat_v), flat_i, flat_v))
ValueError: Could not flatten dictionary. Key had 2 elements, but value had 1 elements. Key: [<tf.Tensor 'seq2seq/encoder3/BiRNN/FW/FW/cond_119/Merge_1:0' shape=(8, 256) dtype=float32>, <tf.Tensor 'seq2seq/encoder3/BiRNN/FW/FW/cond_119/Merge_2:0' shape=(8, 256) dtype=float32>], value: [array([[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
...,
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]]], dtype=float32)].
The command I run for decoding:
bazel-bin/textsum/seq2seq_attention --mode=decode --article_key=article --abstract_key=abstract --data_path=data/predict --vocab_path=data/vocab --log_root=log_root --decode_dir=log_root/decode --beam_size=8 --truncate_input=True
What could be causing this? My environment:
CUDA 7.5
CUDNN 5.1
TensorFlow 0.10
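Update: I installed the previous version of TensorFlow, 0.9, following a comment on this GitHub issue: https://github.com/tensorflow/models/issues/417 and that resolved the problem. I still don't know why it doesn't work with version 0.10.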
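Reading the message literally, the feed_dict key is a pair of tensors (Merge_1:0 and Merge_2:0) while the value being fed is a single array, so the two cannot be matched one to one. The sketch below reproduces only that kind of mismatch in plain Python. It is a simplified stand-in, not TensorFlow's actual nest.flatten_dict_items code, and the names flatten, flatten_feed, state_key and single_array are hypothetical placeholders. It does not explain why 0.10 builds the key as a pair.

```python
# Illustrative sketch only: a simplified stand-in for the structure check,
# NOT TensorFlow's actual nest.flatten_dict_items implementation.


def flatten(x):
    """Flatten nested tuples/lists into a flat list of leaves."""
    if isinstance(x, (tuple, list)):
        leaves = []
        for item in x:
            leaves.extend(flatten(item))
        return leaves
    return [x]


def flatten_feed(feed):
    """Pair every leaf of each key with the matching leaf of its value."""
    flat = {}
    for key, value in feed.items():
        flat_k, flat_v = flatten(key), flatten(value)
        if len(flat_k) != len(flat_v):
            raise ValueError(
                "Could not flatten dictionary. Key had %d elements, "
                "but value had %d elements." % (len(flat_k), len(flat_v)))
        flat.update(zip(flat_k, flat_v))
    return flat


# Hypothetical placeholders: two names standing in for the two state
# tensors in the key, and one opaque object standing in for the single
# array that was fed as the value.
state_key = ("Merge_1:0", "Merge_2:0")
single_array = object()

try:
    flatten_feed({state_key: single_array})
except ValueError as err:
    print(err)  # same kind of complaint as in the traceback

# Feeding one value per tensor in the key passes the check.
print(sorted(flatten_feed({state_key: ("c_state", "h_state")}).items()))
```

If this reading is right, the decode step would have to feed one array per tensor in the key, or the state would have to be a single tensor as it apparently is under 0.9.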