我正在使用强化学习,希望在培训期间减少通过sess.run()提供的数据量,以加快学习速度。
我正在研究LSTM并需要向前看并重置以找到合适的Q值,我用tf.case()制作了一个这样的解决方案:
CurrentStateOption = tf.Variable(0, trainable=False, name='SavedState')
with tf.name_scope("LSTMLayer") as scope:
initializer = tf.random_uniform_initializer(-.1, .1)
lstm_cell_L1 = tf.nn.rnn_cell.LSTMCell(self.input_sizes, forget_bias=1.0, initializer=initializer, state_is_tuple=True)
self.cell_L1 = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_L1] *self.NumberLSTMLayers, state_is_tuple=True)
self.state = self.cell_L1.zero_state(1,tf.float64)
self.SavedState = self.cell_L1.zero_state(1,tf.float64) #tf.Variable(state, trainable=False, name='SavedState')
#SaveCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(1)), self.SaveState, self.SameState)
#RestoreCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(-1)), self.RestoreState, self.SameState)
#ZeroCond = tf.cond(tf.less(CurrentStateOption,tf.constant(-1)), self.ZeroState, self.SameState)
self.state = tf.case({tf.equal(CurrentStateOption,tf.constant(1)): self.SaveState, tf.equal(CurrentStateOption,tf.constant(-1)): self.RestoreState,
tf.less(CurrentStateOption,tf.constant(-1)): self.ZeroState}, default=self.SameState, exclusive=True)
RunConditions = tf.group([SaveCond, RestoreCond, ZeroCond])
self.Xinputs = [tf.concat(1,[Xinputs])]
outputs, stateFINAL_L1 = rnn.rnn(self.cell_L1,self.Xinputs, initial_state=self.state, dtype=tf.float32)
def RestoreState(self): #self.state = self.state.assign(self.SavedState) self.state = self.SavedState return self.state def ZeroState(self): self.state = self.cell_L1.zero_state(1,tf.float64) return self.state def SaveState(self): #self.SavedState = self.SavedState.assign(self.state) self.SavedState = self.state return self.SavedState def SameState(self): return self.state
这似乎在概念上运作良好,因为现在我可以提供INT来指示LSTM Graph如何处理状态。如果我通过" 1"它会在执行前保存状态,如果我通过" -1"它将恢复上次保存的状态,如果我通过"< -1"它会使国家归零。如果" 0"它将使用上次运行(推理)中LSTM中的内容。我尝试了一些不同的方法,包括一个更简单的tf.cond()方法。
我认为这个问题源于需要张量的tf.case()Op,但LSTM状态是一个元组(非元组将被折旧)。当我尝试将值赋值给图变量时,这一点就变得清晰了。
我的最终目标是离开"状态"在图中,但通过INT来指示如何处理状态。在未来,我希望有多个" store"各种回顾的位置。
如何使用元组与张量处理tf.case()类型的结构?
答案 0 :(得分:0)
我相信在状态元组中每个元素都有一个tf.case()应该可以工作,因为元组只是一个python元组。