我对this SARSA FA有疑问。
在输入单元格142中,我看到了此修改后的更新
w += alpha * (reward - discount * q_hat_next) * q_hat_grad
其中q_hat_next
是Q(S', a')
,而q_hat_grad
是Q(S, a)
的派生词(假定S, a, R, S' a'
序列)。
我的问题是更新不应该这样吗?
w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
修改后的更新背后的直觉是什么?
答案 0 :(得分:0)
我认为你是正确的。我还希望更新中包含TD错误项,应为string = 'WWWWWWWWWWWWBWWWWWWWWWWWWBBBWWWWWWWWWWWWWWWWWWWWWWWWB'
x=''.join(['{}{}'.format(k, sum(1 for _ in g)) for k, g in groupby(string)])
。
作为参考,这是实现:
reward + discount * q_hat_next - q_hat
这是来自Reinforcement Learning: An Introduction (by Sutton & Barto)(第171页)的伪代码:
由于实现为TD(0),所以if done: # (terminal state reached)
w += alpha*(reward - q_hat) * q_hat_grad
break
else:
next_action = policy(env, w, next_state, epsilon)
q_hat_next = approx(w, next_state, next_action)
w += alpha*(reward - discount*q_hat_next)*q_hat_grad
state = next_state
为1。则可以简化伪代码中的更新:
n
成为(通过替换w <- w + a[G - v(S_t,w)] * dv(S_t,w)
)
G == reward + discount*v(S_t+1,w))
或在原始代码示例中使用变量名称:
w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)
我最终得到了与您相同的更新公式。看起来像是非终端状态更新中的错误。
只有末尾的情况(如果w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
为真)才是正确的,因为done
在定义上总是为0,因为情节已经结束,无法获得更多奖励。