I want to add an attention mechanism to my RNN model.
Before adding attention, my code looked like this:
> lstm_cell = tf.contrib.rnn.BasicLSTMCell(
>     self.cell_size, forget_bias=1.0, state_is_tuple=True)
> with tf.name_scope('initial_state'):
>     self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
> self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
>     lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False)
That runs fine. After adding the attention mechanism, my code is as follows:
self.cell = tf.contrib.rnn.LSTMCell(self.cell_size)
self.attention_mechanism = tf.contrib.seq2seq.BahdanauMonotonicAttention(
    self.cell_size, self.l_in_y)
with tf.name_scope('audiowarpper'):
    self.attn_cell = tf.contrib.seq2seq.AttentionWrapper(
        self.cell, self.attention_mechanism, self.cell_size,
        alignment_history=True, output_attention=True)
with tf.name_scope('initial_state'):
    self.cell_init_state = self.attn_cell.zero_state(self.batch_size, dtype=tf.float32)
self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
    self.attn_cell, self.l_in_y, initial_state=self.cell_init_state,
    dtype=tf.float32, time_major=False)
But when I run the code, I get this error:
Traceback (most recent call last):
File "/home/wentao/Desktop/wochenende/rnn_weight/rnn2.py", line 178, in <module>
[model.train_op, model.cost, model.cell_final_state, model.pred],feed_dict=model.feed_dict)
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 929, in run
run_metadata_ptr)
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1137, in _run
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 471, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
return _ListFetchMapper(fetch)
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 370, in __init__
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
return _ListFetchMapper(fetch)
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 370, in __init__
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 271, in for_fetch
return _ElementFetchMapper(fetches, contraction_fn)
File "/home/wentao/PycharmProjects/LAS/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 304, in __init__
(fetch, type(fetch), str(e)))
TypeError: Fetch argument <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x7f11249b8e10> has invalid type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>, must be a string or Tensor. (Can not convert a TensorArray into a Tensor or Operation.)
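
For what it's worth, the TypeError points at a `TensorArray` among the fetches. With `alignment_history=True`, the `AttentionWrapperState` that comes back as `self.cell_final_state` carries its alignment history as a `TensorArray`, and `Session.run` cannot fetch that directly. Below is a minimal sketch of one possible workaround: stack the history into a single tensor before fetching. So that it runs without TensorFlow, `AttentionState` and `FakeTensorArray` are hypothetical stand-ins (not library classes), and `make_fetchable` is an illustrative helper name. The only library behaviour relied on is that `tf.TensorArray` has a `stack()` method and that the state is a namedtuple with an `alignment_history` field.

```python
import collections

# Stand-in mirroring some field names of tf.contrib.seq2seq.AttentionWrapperState
# (assumption: the real state is a namedtuple that has these fields).
AttentionState = collections.namedtuple(
    "AttentionState",
    ["cell_state", "attention", "time", "alignments", "alignment_history"])


class FakeTensorArray(object):
    """Hypothetical stand-in for tf.TensorArray; only stack() is imitated."""

    def __init__(self, items):
        self._items = list(items)

    def stack(self):
        # The real TensorArray.stack() returns one Tensor; here, a plain list.
        return list(self._items)


def make_fetchable(state):
    """Return a copy of `state` whose alignment_history has been stacked.

    If alignment_history has no stack() method, the state is returned unchanged.
    """
    history = state.alignment_history
    if hasattr(history, "stack"):
        return state._replace(alignment_history=history.stack())
    return state


state = AttentionState(
    cell_state=("c", "h"), attention="attn", time=3,
    alignments="a3", alignment_history=FakeTensorArray(["a1", "a2", "a3"]))
fetchable = make_fetchable(state)
print(fetchable.alignment_history)
```

In the real graph this would presumably mean building `self.cell_final_state.alignment_history.stack()` once at graph-construction time and fetching that tensor, or fetching only `model.cell_final_state.cell_state` if the alignments are not needed. I have not verified either against this exact setup.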