Feeding data through an embedding wrapper in TensorFlow

Date: 2016-03-31 10:22:30

Tags: python tensorflow

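I'm working on a text summarization network and need to implement an encoder to use with tf.nn.seq2seq.embedding_attention_decoder. As part of that, I need to encode batches of sequences into representation vectors, but the innermost encoding step fails.

Here is a simplified snippet that gives the same error: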

import tensorflow as tf

single_cell = tf.nn.rnn_cell.GRUCell(1024)
sentence_cell = tf.nn.rnn_cell.EmbeddingWrapper(single_cell,
                                                embedding_classes = 40000)
batch = [tf.placeholder(tf.int32, [1,1]) for _ in range(250)]
(_ , state) = tf.nn.rnn(sentence_cell, batch, dtype= tf.int32)

It fails with the following stack trace:

Traceback (most recent call last):
  File "/home/ubuntu/workspace/example.py", line 6, in <module>
    (_ , state) = tf.nn.rnn(sentence_cell, batch, dtype= tf.int32)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 126, in rnn
    (output, state) = call_cell()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 119, in <lambda>
    call_cell = lambda: cell(input_, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 616, in __call__
    return self._cell(embedded, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 150, in __call__
    2 * self._num_units, True, 1.0))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 706, in linear
    res = math_ops.matmul(array_ops.concat(1, args), matrix)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 314, in concat
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 70, in _concat
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 396, in apply_op
    raise TypeError("%s that don't all match." % prefix)
TypeError: Tensors in list passed to 'values' of 'Concat' Op have types [float32, int32] that don't all match.
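Reading this trace, the failing Concat is inside the GRU cell's linear([inputs, state], ...) call. The embedded input is float32, but the state is int32. In the TensorFlow 0.x API, the dtype argument of tf.nn.rnn seems to set the dtype of the initial (zero) state, not of the inputs. That would explain the mismatch. The plain-Python model below reproduces the rule. FakeTensor, concat, zero_state, embed and gru_linear_input are hypothetical stand-ins, not TensorFlow API, and the sizes are made up.

```python
# Plain-Python model of the dtype rule; FakeTensor etc. are hypothetical, not TensorFlow API.
from typing import List, NamedTuple


class FakeTensor(NamedTuple):
    dtype: str               # e.g. "float32" or "int32"
    rows: List[List[float]]  # shape [batch_size, width]


def concat(values: List[FakeTensor]) -> FakeTensor:
    """Concatenate along axis 1, rejecting mixed dtypes like the Concat op does."""
    dtypes = [v.dtype for v in values]
    if len(set(dtypes)) != 1:
        raise TypeError(
            "Tensors in list passed to 'values' of 'Concat' Op have types "
            f"[{', '.join(dtypes)}] that don't all match."
        )
    batch_size = len(values[0].rows)
    rows = [sum((v.rows[b] for v in values), []) for b in range(batch_size)]
    return FakeTensor(dtypes[0], rows)


def zero_state(batch_size: int, num_units: int, dtype: str) -> FakeTensor:
    # Assumption: as with tf.nn.rnn's `dtype`, this only sets the initial state's dtype.
    return FakeTensor(dtype, [[0] * num_units for _ in range(batch_size)])


def embed(ids: FakeTensor, embedding_size: int) -> FakeTensor:
    # An embedding lookup yields float32 vectors whatever the dtype of the ids.
    return FakeTensor("float32", [[0.1 * row[0]] * embedding_size for row in ids.rows])


def gru_linear_input(ids: FakeTensor, state: FakeTensor) -> FakeTensor:
    # The GRU's linear([inputs, state], ...) concatenates the embedded input with the state.
    return concat([embed(ids, embedding_size=4), state])


ids = FakeTensor("int32", [[7]])  # one time step, batch_size = 1
print(gru_linear_input(ids, zero_state(1, 3, "float32")).dtype)  # float32
# gru_linear_input(ids, zero_state(1, 3, "int32")) raises the Concat TypeError
```

In this model the ids can stay int32 while the state is float32. Only the state's dtype has to match the float32 embedding.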

While debugging, sentence_cell has an input size of 1, and every element of batch has shape [1, 1], i.e. [batch_size, sentence_cell.input_size].
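For reference, here is how the shapes appear to line up. Each time step is a [batch_size, 1] tensor of word ids. The wrapper flattens it (the second trace shows array_ops.reshape(inputs, [-1])) and then looks up one embedding row per id, which gives [batch_size, embedding_size]. The toy pure-Python sketch below models that flow. It uses made-up sizes (5 classes and width 3 instead of 40000 classes and the cell's width), and embed_step is a hypothetical name.

```python
# Hypothetical pure-Python model of the shape flow inside an embedding wrapper.
from typing import List


def embed_step(ids: List[List[int]], embedding: List[List[float]]) -> List[List[float]]:
    """Map [batch_size, 1] integer ids to [batch_size, embedding_size] vectors."""
    # Mirrors reshape(inputs, [-1]): drop the trailing size-1 axis.
    flat_ids = [row[0] for row in ids]
    # Mirrors embedding_lookup / Gather: pick one embedding row per id.
    return [embedding[i] for i in flat_ids]


embedding = [[float(c)] * 3 for c in range(5)]  # 5 classes, embedding_size = 3
batch_step = [[4]]  # one time step with batch_size = 1, input_size = 1
print(embed_step(batch_step, embedding))  # [[4.0, 4.0, 4.0]]
```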

Switching to dtype=tf.float32 in the call to tf.nn.rnn() makes the snippet work, but in my own code it gives the following stack trace:

[nltk_data] Downloading package punkt to /home/alex/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Preparing news data in .
Creating 3 layers of 1024 units.
> /home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py(84)encode_sentence()
-> (_ ,state) = tf.nn.rnn(sentence_cell, sent, sequence_length = length, dtype= tf.float32)
(Pdb) c
Traceback (most recent call last):
  File "translate.py", line 268, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/default/_app.py", line 30, in run
    sys.exit(main(sys.argv))
  File "translate.py", line 265, in main
    train()
  File "translate.py", line 161, in train
    model = create_model(sess, False)
  File "translate.py", line 136, in create_model
    forward_only=forward_only)
  File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 141, in __init__
    softmax_loss_function=None)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/seq2seq.py", line 926, in model_with_buckets
    decoder_inputs[:bucket[1]])
  File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 140, in <lambda>
    lambda x, y: seq3seq_f(x, y, False),
  File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 98, in seq3seq_f
    art_vecs = tfmap(encode_article, tf.pack(encoder_inputs))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1900, in map
    _, r_a = While(lambda i, a: math_ops.less(i, n), compute, [i, acc_ta])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1557, in While
    result = context.BuildLoop(cond, body, loop_vars)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1474, in BuildLoop
    body_result = body(*vars_for_body_with_tensor_arrays)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1897, in compute
    a = a.write(i, fn(elems_ta.read(i)))
  File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 92, in encode_article
    return tfmap(encode_sentence, article)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1900, in map
    _, r_a = While(lambda i, a: math_ops.less(i, n), compute, [i, acc_ta])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1557, in While
    result = context.BuildLoop(cond, body, loop_vars)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1474, in BuildLoop
    body_result = body(*vars_for_body_with_tensor_arrays)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1897, in compute
    a = a.write(i, fn(elems_ta.read(i)))
  File "/home/alex/Programmering/kandidatarbete/arbete/code/seq3seq/seq3seq_model.py", line 84, in encode_sentence
    (_ ,state) = tf.nn.rnn(sentence_cell, sent, sequence_length = length, dtype= tf.float32)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 124, in rnn
    zero_output, state, call_cell)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 212, in _rnn_step
    time < max_sequence_length, call_cell, empty_update)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1183, in cond
    res_t = context_t.BuildCondBranch(fn1)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1106, in BuildCondBranch
    r = fn()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 119, in <lambda>
    call_cell = lambda: cell(input_, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 615, in __call__
    embedding, array_ops.reshape(inputs, [-1]))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/embedding_ops.py", line 86, in embedding_lookup
    validate_indices=validate_indices)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 423, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 486, in apply_op
    _Attr(op_def, input_arg.type_attr))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 59, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: DataType float32 for attr 'Tindices' not in list of allowed values: int32, int64
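This second trace fails in a different place. With dtype=tf.float32 the state apparently no longer clashes, but now the ids reaching the embedding lookup (reshape(inputs, [-1]) feeding Gather) are float32. Gather only accepts int32 or int64 indices. So the two dtypes seem to play different roles: dtype in tf.nn.rnn describes the state, while the inputs to an EmbeddingWrapper must stay integer. The question does not show where the float32 ids come from in seq3seq_model.py. In TensorFlow the analogous fix would presumably be casting each time step, e.g. with tf.cast(..., tf.int32), before it reaches the wrapper, but that is an untested assumption. The sketch below only models the rule in plain Python. gather and to_int_ids are hypothetical helpers.

```python
# Plain-Python model of the Gather index check; not TensorFlow code.
from typing import List, Sequence


def gather(embedding: List[List[float]], indices: Sequence) -> List[List[float]]:
    """Return one embedding row per index; non-integer indices are rejected."""
    for i in indices:
        # bool is a subclass of int in Python, so exclude it explicitly.
        if isinstance(i, bool) or not isinstance(i, int):
            raise TypeError(
                f"DataType {type(i).__name__} for attr 'Tindices' "
                "not in list of allowed values: int32, int64"
            )
    return [embedding[i] for i in indices]


def to_int_ids(values: Sequence[float]) -> List[int]:
    """Hypothetical helper: cast float ids such as 2.0 back to int ids."""
    ids = []
    for v in values:
        if v != int(v):
            raise ValueError(f"{v} is not a whole-number token id")
        ids.append(int(v))
    return ids


embedding = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
float_ids = [2.0, 1.0]  # ids that became floats somewhere upstream
# gather(embedding, float_ids) raises TypeError, like the Tindices error above
print(gather(embedding, to_int_ids(float_ids)))  # [[2.0, 2.0], [1.0, 1.0]]
```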

What am I missing?
