在Double DQN(在CNTK中实现)中,我试图使用在线模型计算下一个状态(post_state_var)的值。为了对我的解决方案进行矢量化,我已经使用了one_hot操作。但是,当我尝试训练时,我收到以下错误:
节点" OneHot"可以用于训练,但它不参与梯度传播。
我已将我的模型和输入定义为:
state_var = cntk.input_variable(state_shape, name='state')
action_var = cntk.input_variable(1, name='action')
reward_var = cntk.input_variable(1, name='reward')
post_state_var = cntk.input_variable(state_shape, name='post_state')
terminal_var = cntk.input_variable(1, name='terminal')
with cntk.default_options(activation=relu):
model_fn = Sequential([
Dense(32, name='h1'),
Dense(32, name='h2'),
Dense(action_shape, name='action')
])
model = model_fn(state_var)
target_model = model.clone(cntk.CloneMethod.freeze)
然后我计算目标值并定义损失如下:
# Value of action selected at state t
state_value = cntk.reduce_sum(model * one_hot(action_var, num_classes=action_shape), axis=1)
# Double Q learning - Value of action selected at state t+1
online_post_state_model = model_fn(post_state_var)
online_post_state_best_action = cntk.argmax(online_post_state_model)
post_state_best_value = cntk.reduce_sum(target_model *
one_hot(online_post_state_best_action, num_classes=action_shape))
gamma = 0.99
target = reward_var + (1.0 - terminal_var) * gamma * post_state_best_value
# MSE for simplicity
td_error = state_value - cntk.stop_gradient(target)
loss = cntk.reduce_mean(cntk.square(td_error))
如果我更换
online_post_state_model = model_fn(post_state_var)
带
online_post_state_model = model_fn.clone(cntk.CloneMethod.freeze)(post_state_var)
然后错误消失了,但这是错误的,因为它使用旧的冻结模型来计算目标。如何使用model_fn
评估post_state_var
并排除反向传播的输出?我没有正确使用stop_gradient
吗?
答案 0 :(得分:0)
one_hot
的典型用法是输入数据,您通常不需要反向传播。
解决方法是将操作作为图表中的一个热矢量。您可以使用hardmax
代替argmax
来完成此操作。