车杆的SARSA值近似值

时间:2018-07-17 01:26:09

标签: machine-learning reinforcement-learning openai-gym sarsa

我对this SARSA FA有疑问。

在输入单元格142中,我看到了此修改后的更新

w += alpha * (reward - discount * q_hat_next) * q_hat_grad

其中q_hat_nextQ(S', a'),而q_hat_gradQ(S, a)的派生词(假定S, a, R, S' a'序列)。

我的问题是更新不应该这样吗?

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad

修改后的更新背后的直觉是什么?

1 个答案:

答案 0 :(得分:0)

我认为你是正确的。我还希望更新中包含TD错误项,应为string = 'WWWWWWWWWWWWBWWWWWWWWWWWWBBBWWWWWWWWWWWWWWWWWWWWWWWWB' x=''.join(['{}{}'.format(k, sum(1 for _ in g)) for k, g in groupby(string)])

作为参考,这是实现:

reward + discount * q_hat_next - q_hat

这是来自Reinforcement Learning: An Introduction (by Sutton & Barto)(第171页)的伪代码:

enter image description here

由于实现为TD(0),所以if done: # (terminal state reached) w += alpha*(reward - q_hat) * q_hat_grad break else: next_action = policy(env, w, next_state, epsilon) q_hat_next = approx(w, next_state, next_action) w += alpha*(reward - discount*q_hat_next)*q_hat_grad state = next_state 为1。则可以简化伪代码中的更新:

n

成为(通过替换w <- w + a[G - v(S_t,w)] * dv(S_t,w)

G == reward + discount*v(S_t+1,w))

或在原始代码示例中使用变量名称:

w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)

我最终得到了与您相同的更新公式。看起来像是非终端状态更新中的错误。

只有末尾的情况(如果w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad 为真)才是正确的,因为done在定义上总是为0,因为情节已经结束,无法获得更多奖励。