TensorFlow:图形中的LSTM状态保存/更新

时间:2016-09-06 14:30:18

标签: tensorflow reinforcement-learning

我正在使用强化学习,希望在培训期间减少通过sess.run()提供的数据量,以加快学习速度。

我正在研究LSTM并需要向前看并重置以找到合适的Q值,我用tf.case()制作了一个这样的解决方案:

    CurrentStateOption = tf.Variable(0, trainable=False, name='SavedState')
    with tf.name_scope("LSTMLayer") as scope:
        initializer = tf.random_uniform_initializer(-.1, .1)
        lstm_cell_L1 = tf.nn.rnn_cell.LSTMCell(self.input_sizes, forget_bias=1.0, initializer=initializer, state_is_tuple=True)
        self.cell_L1 = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_L1] *self.NumberLSTMLayers, state_is_tuple=True)
        self.state = self.cell_L1.zero_state(1,tf.float64)

        self.SavedState = self.cell_L1.zero_state(1,tf.float64)   #tf.Variable(state, trainable=False, name='SavedState')

        #SaveCond    = tf.cond(tf.equal(CurrentStateOption,tf.constant(1)), self.SaveState, self.SameState)
        #RestoreCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(-1)), self.RestoreState, self.SameState)
        #ZeroCond    = tf.cond(tf.less(CurrentStateOption,tf.constant(-1)), self.ZeroState, self.SameState)

        self.state = tf.case({tf.equal(CurrentStateOption,tf.constant(1)): self.SaveState, tf.equal(CurrentStateOption,tf.constant(-1)): self.RestoreState,
            tf.less(CurrentStateOption,tf.constant(-1)): self.ZeroState}, default=self.SameState, exclusive=True)

        RunConditions = tf.group([SaveCond, RestoreCond, ZeroCond])

        self.Xinputs = [tf.concat(1,[Xinputs])]

        outputs, stateFINAL_L1 = rnn.rnn(self.cell_L1,self.Xinputs, initial_state=self.state, dtype=tf.float32)
def RestoreState(self):
    #self.state = self.state.assign(self.SavedState)
    self.state = self.SavedState
    return self.state
def ZeroState(self):
    self.state = self.cell_L1.zero_state(1,tf.float64)
    return self.state
def SaveState(self):
    #self.SavedState = self.SavedState.assign(self.state)
    self.SavedState = self.state
    return  self.SavedState
def SameState(self):
    return self.state

这似乎在概念上运作良好,因为现在我可以提供INT来指示LSTM Graph如何处理状态。如果我通过" 1"它会在执行前保存状态,如果我通过" -1"它将恢复上次保存的状态,如果我通过"< -1"它会使国家归零。如果" 0"它将使用上次运行(推理)中LSTM中的内容。我尝试了一些不同的方法,包括一个更简单的tf.cond()方法。

我认为这个问题源于需要张量的tf.case()Op,但LSTM状态是一个元组(非元组将被折旧)。当我尝试将值赋值给图变量时,这一点就变得清晰了。

我的最终目标是离开"状态"在图中,但通过INT来指示如何处理状态。在未来,我希望有多个" store"各种回顾的位置。

如何使用元组与张量处理tf.case()类型的结构?

1 个答案:

答案 0 :(得分:0)

我相信在状态元组中每个元素都有一个tf.case()应该可以工作,因为元组只是一个python元组。