在dynamic_rnn()中,在状态中包含变量是否有效?

时间:2017-01-31 17:51:13

标签: tensorflow

我在BasicLSTMCell周围实现一个RNN单元,我希望能够回顾过去的隐藏状态(跨批处理边界)。我使用dynamic_rnn(),我使用的基本模式是:

def __call__(self, inputs, old_state, scope=None):
    mem = old_state[2]

    # [do something with with mem]

    cell_out, new_state = self.cell(inputs,
                                    (old_state[0],
                                     old_state[1]))

    h_state = new_state.h
    c_state = new_state.c

    # control dependency required because of self.buf_index
    with tf.get_default_graph().control_dependencies([cell_out]):
        new_mem = write_to_buf(self.out_buf,
                               cell_out,
                               self.buf_index)

    # update the buffer index
    with tf.get_default_graph().control_dependencies(new_mem):
        inc_step = tf.assign(self.buf_index, (self.buf_index + 1) %
                             self.buf_size)
        with tf.get_default_graph().control_dependencies([inc_step]):
            h_state = tf.identity(h_state)
    t = [c_state, h_state, new_mem]
    return cell_out, tuple(t)

self.bufself.buf_index是变量。 write_to_buf()是一个函数,它使用scatter_update()将新的隐藏状态写入缓冲区并返回结果。

我依赖于对散布更新的访问返回值的假设,保证使用新的变量值(类似于this),这样变量的缓存不会搞砸。

从调试打印出来似乎可以正常工作,但最好能获得一些关于替代品的确认或建议。

0 个答案:

没有答案