"Can not convert a TensorArray into a Tensor or Operation" after adding an attention mechanism

Date: 2019-04-21 08:11:08

Tags: tensorflow recurrent-neural-network

I want to add an attention mechanism to my RNN model.

Before adding the attention mechanism, my code looked like this:

```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.cell_size, forget_bias=1.0, state_is_tuple=True)
with tf.name_scope('initial_state'):
    self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
    lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False)
```

This runs fine. After adding the attention mechanism, my code is as follows:

```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

self.cell = tf.contrib.rnn.LSTMCell(self.cell_size)
# Monotonic Bahdanau attention whose memory is the input sequence itself
self.attention_mechanism = tf.contrib.seq2seq.BahdanauMonotonicAttention(self.cell_size, self.l_in_y)
with tf.name_scope('audiowarpper'):
    # alignment_history=True keeps the per-step alignments in a TensorArray inside the state
    self.attn_cell = tf.contrib.seq2seq.AttentionWrapper(
        self.cell, self.attention_mechanism, self.cell_size,
        alignment_history=True, output_attention=True)
with tf.name_scope('initial_state'):
    self.cell_init_state = self.attn_cell.zero_state(self.batch_size, dtype=tf.float32)
self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
    self.attn_cell, self.l_in_y, initial_state=self.cell_init_state, dtype=tf.float32, time_major=False)
```

But when I run the code, I get this error:

```
Traceback (most recent call last):
  File "/home/wentao/Desktop/wochenende/rnn_weight/rnn2.py", line 178, in <module>
    [model.train_op, model.cost, model.cell_final_state, model.pred],feed_dict=model.feed_dict)
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1137, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 471, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 370, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 370, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 271, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 304, in __init__
    (fetch, type(fetch), str(e)))
TypeError: Fetch argument <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x7f11249b8e10> has invalid type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>, must be a string or Tensor. (Can not convert a TensorArray into a Tensor or Operation.)
```
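The traceback shows that one of the fetched values (inside `model.cell_final_state`) is a `TensorArray`. With `alignment_history=True`, the `alignment_history` field of the final `AttentionWrapperState` is a `TensorArray`, and `Session.run` only accepts graph elements such as Tensors and Operations (possibly nested in lists, tuples, namedtuples or dicts). If the alignment history is not needed, `alignment_history=False` or fetching only `model.cell_final_state.cell_state` avoids the problem. If it is needed, one option is to replace that field with its stacked form when building the graph and fetch the result instead. The sketch below shows that replacement pattern in plain Python; `FakeTensorArray`, `FakeState` and `make_fetchable` are hypothetical stand-ins, not TensorFlow APIs, so it runs without TensorFlow installed.

```python
from collections import namedtuple


class FakeTensorArray:
    """Hypothetical stand-in for tf.TensorArray: not directly fetchable,
    but convertible to a single value with .stack()."""

    def __init__(self, steps):
        self._steps = list(steps)

    def stack(self):
        # tf.TensorArray.stack() packs all elements into one Tensor;
        # this stand-in just returns them as a list.
        return list(self._steps)


# Hypothetical, simplified stand-in for AttentionWrapperState (a namedtuple);
# the real state has more fields.
FakeState = namedtuple("FakeState", ["cell_state", "attention", "alignment_history"])


def make_fetchable(state):
    """Return a copy of `state` whose alignment_history is stacked.

    With alignment_history=False the real field is an empty tuple, so the
    state is returned unchanged in that case.
    """
    history = state.alignment_history
    if isinstance(history, FakeTensorArray):
        # namedtuple._replace returns a new tuple with one field swapped;
        # with TensorFlow this would be done once, at graph-construction time.
        return state._replace(alignment_history=history.stack())
    return state


state = FakeState(cell_state="c", attention="a",
                  alignment_history=FakeTensorArray([[0.1, 0.9], [0.7, 0.3]]))
print(make_fetchable(state).alignment_history)  # [[0.1, 0.9], [0.7, 0.3]]
```

In the model above, the corresponding (untested) TensorFlow 1.x line would be something like `self.fetch_state = self.cell_final_state._replace(alignment_history=self.cell_final_state.alignment_history.stack())`, placed after the `dynamic_rnn` call, with `model.fetch_state` fetched in place of `model.cell_final_state`; `fetch_state` is a hypothetical attribute name.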

0 answers