AttributeError: 'Tensor' object has no attribute 'attention'

Date: 2017-07-14 07:35:10

Tags: tensorflow nlp

I'm trying to use dynamic_decode in TensorFlow to build an attention model. The code is based on the original version from https://github.com/tensorflow/nmt#decoder:
import tensorflow as tf
from tensorflow.python.layers.core import Dense

learning_rate = 0.001
n_hidden = 128
total_epoch = 10000
num_units=128
n_class = n_input = 47

num_steps=8
embedding_size=30


mode = tf.placeholder(tf.bool)
embed_enc = tf.placeholder(tf.float32, shape=[None,num_steps,300])
embed_dec = tf.placeholder(tf.float32, shape=[None,num_steps,300])
targets=tf.placeholder(tf.int32, shape=[None,num_steps])

enc_seqlen = tf.placeholder(tf.int32, shape=[None])
dec_seqlen = tf.placeholder(tf.int32, shape=[None])
decoder_weights= tf.placeholder(tf.float32, shape=[None, num_steps])

with tf.variable_scope('encode'):
    enc_cell = tf.contrib.rnn.BasicRNNCell(n_hidden)
    enc_cell = tf.contrib.rnn.DropoutWrapper(enc_cell, output_keep_prob=0.5)
    outputs, enc_states = tf.nn.dynamic_rnn(enc_cell, embed_enc,sequence_length=enc_seqlen, 
                                            dtype=tf.float32,time_major=True )


attention_states = tf.transpose(outputs, [1, 0, 2])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, attention_states,
    memory_sequence_length=enc_seqlen)


decoder_cell = tf.contrib.rnn.BasicLSTMCell(num_units)
decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=num_units)

helper = tf.contrib.seq2seq.TrainingHelper(
    embed_dec, dec_seqlen, time_major=True)
# Decoder
projection_layer = Dense(
    47, use_bias=False)
decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, enc_states,
    output_layer=projection_layer)
# Dynamic decoding
outputs, _ = tf.contrib.seq2seq.dynamic_decode(decoder)

But when I run

tf.contrib.seq2seq.dynamic_decode(decoder)

I get the following error:

Traceback (most recent call last):
  File "<ipython-input-19-0708495dbbfb>", line 27, in <module>
    outputs, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
  File "D:\Anaconda3\lib\site-packages\tensorflow\contrib\seq2seq\python\ops\decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "D:\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2775, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "D:\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2604, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "D:\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2554, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "D:\Anaconda3\lib\site-packages\tensorflow\contrib\seq2seq\python\ops\decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "D:\Anaconda3\lib\site-packages\tensorflow\contrib\seq2seq\python\ops\basic_decoder.py", line 139, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "D:\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "D:\Anaconda3\lib\site-packages\tensorflow\python\layers\base.py", line 450, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "D:\Anaconda3\lib\site-packages\tensorflow\contrib\seq2seq\python\ops\attention_wrapper.py", line 1143, in call
    cell_inputs = self._cell_input_fn(inputs, state.attention)
AttributeError: 'Tensor' object has no attribute 'attention'

I tried installing the latest TensorFlow, 1.2.1, but that didn't help. Thanks for your help.

Update:

It turns out that if I change the initial state of the BasicDecoder from:

decoder = tf.contrib.seq2seq.BasicDecoder(
      decoder_cell, helper, enc_states,
      output_layer=projection_layer)

to:

decoder = tf.contrib.seq2seq.BasicDecoder(
      decoder_cell, helper,  
      decoder_cell.zero_state(dtype=tf.float32,batch_size=batch_size),
      output_layer=projection_layer)

then it works. I'm not sure whether this is a correct solution, though, because setting the initial state to zeros seems weird. Thanks for your help.
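Why the change works can be illustrated with a small pure-Python mock (not TensorFlow code). `AttentionWrapper` expects the state it receives to be an `AttentionWrapperState` namedtuple and reads `state.attention` on every step, which is the line that fails in the traceback. A bare encoder Tensor has no such field, while the wrapper's own `zero_state()` returns the full namedtuple. `PlainTensor`, `wrapper_call` and `zero_state` below are hypothetical simplifications, and the mocked state has fewer fields than the real class.

```python
from collections import namedtuple

# Simplified, hypothetical stand-in for tf.contrib.seq2seq.AttentionWrapperState;
# the real namedtuple has further fields (e.g. alignments).
AttentionWrapperState = namedtuple(
    "AttentionWrapperState", ["cell_state", "attention", "time"])


class PlainTensor:
    """Hypothetical stand-in for a bare tf.Tensor such as the encoder's enc_states."""

    def __init__(self, value):
        self.value = value


def wrapper_call(inputs, state):
    # Mirrors the failing line in attention_wrapper.py:
    #     cell_inputs = self._cell_input_fn(inputs, state.attention)
    # The default input fn concatenates inputs and attention; lists model that here.
    return inputs + state.attention


def zero_state(num_units):
    # Hypothetical analogue of AttentionWrapper.zero_state(): every field zeroed.
    return AttentionWrapperState(
        cell_state=[0.0] * num_units, attention=[0.0] * num_units, time=0)


# Passing the encoder state directly reproduces the error from the question.
try:
    wrapper_call([1.0, 2.0], PlainTensor([0.5, 0.5]))
except AttributeError as err:
    print(err)  # 'PlainTensor' object has no attribute 'attention'

# Passing a complete wrapper state works.
print(wrapper_call([1.0, 2.0], zero_state(2)))  # [1.0, 2.0, 0.0, 0.0]
```

The point of the mock is only the type mismatch: whatever is passed as the BasicDecoder's initial state must have the wrapper's state structure, not the encoder cell's.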

2 Answers:

Answer 0 (score: 0)

Could you rewrite all of your calls using keyword arguments everywhere? I think your arguments are misplaced somewhere, and using kwargs should resolve it.

Answer 1 (score: 0)

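Your approach is correct. I've added a better error message in the master branch for future users. Since you're using attention, you probably don't need to pass anything into the decoder's initial state. However, it is still common to feed in the encoder's final state. You can do this by creating the decoder cell's zero state the way you're doing, and then calling its .clone method with the argument cell_state=encoder_final_state. Use the resulting object as the initial decoder state.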
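The `.clone` step can be sketched with a pure-Python mock, assuming (as in the TF 1.x source, to my understanding) that `AttentionWrapperState.clone(**kwargs)` behaves like namedtuple `_replace`: it returns a new state with only the named fields swapped out. The class below is a hypothetical simplification with fewer fields than the real one, and `encoder_final_state` is made-up data.

```python
from collections import namedtuple


class AttentionWrapperState(
        namedtuple("AttentionWrapperState", ["cell_state", "attention", "time"])):
    """Simplified, hypothetical stand-in for tf.contrib.seq2seq.AttentionWrapperState."""

    def clone(self, **kwargs):
        # Return a new state with the named fields replaced; the other
        # fields (attention, time) are copied unchanged from the zero state.
        return self._replace(**kwargs)


# Zero state as produced by decoder_cell.zero_state(...), here with 2 units.
zero = AttentionWrapperState(cell_state=(0.0, 0.0), attention=(0.0, 0.0), time=0)

# Hypothetical encoder final state, assumed to match the decoder cell's state structure.
encoder_final_state = (0.3, -0.7)

initial_state = zero.clone(cell_state=encoder_final_state)
print(initial_state)
# AttentionWrapperState(cell_state=(0.3, -0.7), attention=(0.0, 0.0), time=0)
```

In the question's TF 1.x graph this would correspond to `decoder_cell.zero_state(dtype=tf.float32, batch_size=batch_size).clone(cell_state=enc_states)`. One caveat: the cloned `cell_state` must have the same structure as the wrapped cell's state. The question pairs a `BasicRNNCell` encoder (whose state is a single Tensor) with a `BasicLSTMCell` decoder (whose state is an `LSTMStateTuple`), so one of the two cells would likely need to change before the encoder state can be passed through this way.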