我可以在RNNCell的__call__中存储状态

时间:2017-03-06 04:20:05

标签: tensorflow

我想构建自己的RNNCell(对于输出为[-1,0,1]的非常简单的单元格),但是在我的__call__我做的计算依赖于前一步的输出。

所以我的问题是,是否可以在__call__方法中保持状态以便在调用之间重用?

class MyCell(RNNCell):
    # Size of my state
    # My state consists of 1 tensor with num_units columns
    @property
    def state_size(self):
        return self._num_units

    # I emit at every timestep
    @property
    def output_size(self):
        return self._num_units

    def __call__(self,input,state):
        #Intermediate calculations for 1 time step
        #Can i keep state here, for example info about last input
        #or output?
        return output, new_state

1 个答案:

答案 0 :(得分:1)

这不适用于tf.dynamic_rnn,因此不鼓励这样做。如果您希望它工作,则通过state参数传递所有状态。它可能适用于普通tf.rnn,但不能保证。