I'm implementing an RNN cell around BasicLSTMCell where I want to be able to look back at past hidden states (across batch boundaries). I'm using dynamic_rnn(), and the basic pattern I follow is:
def __call__(self, inputs, old_state, scope=None):
    mem = old_state[2]
    # [do something with mem]
    cell_out, new_state = self.cell(inputs,
                                    (old_state[0],
                                     old_state[1]))
    h_state = new_state.h
    c_state = new_state.c
    # control dependency required because of self.buf_index
    with tf.get_default_graph().control_dependencies([cell_out]):
        new_mem = write_to_buf(self.out_buf,
                               cell_out,
                               self.buf_index)
    # update the buffer index
    with tf.get_default_graph().control_dependencies([new_mem]):
        inc_step = tf.assign(self.buf_index, (self.buf_index + 1) %
                             self.buf_size)
    with tf.get_default_graph().control_dependencies([inc_step]):
        h_state = tf.identity(h_state)
    t = [c_state, h_state, new_mem]
    return cell_out, tuple(t)
self.out_buf and self.buf_index are variables. write_to_buf() is a function that uses scatter_update() to write the new hidden state into the buffer and returns the result.
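For concreteness, a minimal sketch of what write_to_buf() does (the buffer shape [buf_size, batch_size, num_units] here is illustrative; the real code differs in details):

def write_to_buf(buf, value, index):
    # buf:   Variable of shape [buf_size, batch_size, num_units]
    # value: Tensor of shape [batch_size, num_units] (the new hidden state)
    # index: scalar integer Variable, the current write position
    # scatter_update() returns the updated ref; returning that (rather
    # than the raw variable) lets callers depend on the write.
    return tf.scatter_update(buf, index, value)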
I rely on the assumption that accesses to the return value of scatter_update() are guaranteed to use the updated variable value (similar to this), so that variable caching does not mess things up.
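Concretely, the assumption is that a read that goes through the op's output sees the write, as in this standalone toy example (not from the actual model):

import tensorflow as tf

v = tf.Variable(tf.zeros([4]))
updated = tf.scatter_update(v, [0], [1.0])
# Reading through `updated` (not `v`) makes the read depend on the
# write, so it should observe the new value despite variable caching.
y = tf.identity(updated)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y))  # expected: [1. 0. 0. 0.]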
From debug prints it seems to work, but it would be great to get some confirmation, or suggestions about alternatives.