我正在开发一个文本摘要网络，需要实现一个与 tf.nn.seq2seq.embedding_attention_decoder 配合使用的编码器。作为其中的一部分，我需要将不同批次的序列编码为表示向量，但最内层（句子级）的编码步骤无法正常运行。
下面是一个简化的代码段，它会产生同样的错误：
import tensorflow as tf  # legacy pre-1.0 TensorFlow API (tf.nn.rnn, tf.nn.rnn_cell); minimal repro script
single_cell = tf.nn.rnn_cell.GRUCell(1024)  # GRU cell with 1024 units
sentence_cell = tf.nn.rnn_cell.EmbeddingWrapper(single_cell,
embedding_classes = 40000)  # does an embedding lookup over 40000 id classes, then feeds the float embedding to the GRU
batch = [tf.placeholder(tf.int32, [1,1]) for _ in range(250)]  # 250 int32 id placeholders of shape [1,1], one input per rnn step
(_ , state) = tf.nn.rnn(sentence_cell, batch, dtype= tf.int32)  # NOTE(review): dtype is the initial *state* dtype, not the input dtype; int32 makes the GRU concat float32 embeddings with an int32 state (the TypeError below). Use tf.float32 here; the inputs themselves must stay int32 for the embedding lookup.
运行时失败，堆栈跟踪如下：
Traceback (most recent call last):
File "/home/ubuntu/workspace/example.py", line 6, in <module>
(_ , state) = tf.nn.rnn(sentence_cell, batch, dtype= tf.int32)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 126, in rnn
(output, state) = call_cell()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 119, in <lambda>
call_cell = lambda: cell(input_, state)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 616, in __call__
return self._cell(embedded, state)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 150, in __call__
2 * self._num_units, True, 1.0))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 706, in linear
res = math_ops.matmul(array_ops.concat(1, args), matrix)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 314, in concat
name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 70, in _concat
name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 396, in apply_op
raise TypeError("%s that don't all match." % prefix)
TypeError: Tensors in list passed to 'values' of 'Concat' Op have types [float32, int32] that don't all match.
调试时发现，sentence_cell 的输入大小为 1，而 batch 中的元素维度均为 [1,1]，即 [batch_size, sentence_cell.input_size]。
将 tf.nn.rnn() 调用中的参数改为 dtype=tf.float32 后，上面的代码段可以正常运行，但在我的实际代码中会出现以下堆栈跟踪：
[nltk_data] Downloading package punkt to /home/alex/nltk_data...
[nltk_data] Package punkt is already up-to-date!
Preparing news data in .
Creating 3 layers of 1024 units.
> /home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py(84)encode_sentence()
-> (_ ,state) = tf.nn.rnn(sentence_cell, sent, sequence_length = length, dtype= tf.float32)
(Pdb) c
Traceback (most recent call last):
File "translate.py", line 268, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/default/_app.py", line 30, in run
sys.exit(main(sys.argv))
File "translate.py", line 265, in main
train()
File "translate.py", line 161, in train
model = create_model(sess, False)
File "translate.py", line 136, in create_model
forward_only=forward_only)
File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 141, in __init__
softmax_loss_function=None)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/seq2seq.py", line 926, in model_with_buckets
decoder_inputs[:bucket[1]])
File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 140, in <lambda>
lambda x, y: seq3seq_f(x, y, False),
File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 98, in seq3seq_f
art_vecs = tfmap(encode_article, tf.pack(encoder_inputs))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1900, in map
_, r_a = While(lambda i, a: math_ops.less(i, n), compute, [i, acc_ta])
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1557, in While
result = context.BuildLoop(cond, body, loop_vars)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1474, in BuildLoop
body_result = body(*vars_for_body_with_tensor_arrays)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1897, in compute
a = a.write(i, fn(elems_ta.read(i)))
File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 92, in encode_article
return tfmap(encode_sentence, article)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1900, in map
_, r_a = While(lambda i, a: math_ops.less(i, n), compute, [i, acc_ta])
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1557, in While
result = context.BuildLoop(cond, body, loop_vars)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1474, in BuildLoop
body_result = body(*vars_for_body_with_tensor_arrays)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1897, in compute
a = a.write(i, fn(elems_ta.read(i)))
File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 84, in encode_sentence
(_ ,state) = tf.nn.rnn(sentence_cell, sent, sequence_length = length, dtype= tf.float32)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 124, in rnn
zero_output, state, call_cell)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 212, in _rnn_step
time < max_sequence_length, call_cell, empty_update)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1183, in cond
res_t = context_t.BuildCondBranch(fn1)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1106, in BuildCondBranch
r = fn()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 119, in <lambda>
call_cell = lambda: cell(input_, state)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 615, in __call__
embedding, array_ops.reshape(inputs, [-1]))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/embedding_ops.py", line 86, in embedding_lookup
validate_indices=validate_indices)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 423, in gather
validate_indices=validate_indices, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 486, in apply_op
_Attr(op_def, input_arg.type_attr))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 59, in _SatisfiesTypeConstraint
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: DataType float32 for attr 'Tindices' not in list of allowed values: int32, int64
我遗漏了什么？