从反向传播中排除OneHot操作

时间:2017-07-27 02:29:11

标签: python-3.x backpropagation reinforcement-learning cntk one-hot-encoding

在Double DQN(在CNTK中实现)中,我试图使用在线模型计算下一个状态(post_state_var)的值。为了对我的解决方案进行矢量化,我已经使用了one_hot操作。但是,当我尝试训练时,我收到以下错误:

  

节点" OneHot"可以用于训练,但它不参与梯度传播。

我已将我的模型和输入定义为:

state_var = cntk.input_variable(state_shape, name='state')
action_var = cntk.input_variable(1, name='action')
reward_var = cntk.input_variable(1, name='reward')
post_state_var = cntk.input_variable(state_shape, name='post_state')
terminal_var = cntk.input_variable(1, name='terminal')

with cntk.default_options(activation=relu):
    model_fn = Sequential([
        Dense(32, name='h1'),
        Dense(32, name='h2'),
        Dense(action_shape, name='action')
    ])

model = model_fn(state_var)
target_model = model.clone(cntk.CloneMethod.freeze)

然后我计算目标值并定义损失如下:

# Value of action selected at state t
state_value = cntk.reduce_sum(model * one_hot(action_var, num_classes=action_shape), axis=1)

# Double Q learning - Value of action selected at state t+1
online_post_state_model = model_fn(post_state_var)
online_post_state_best_action = cntk.argmax(online_post_state_model)
post_state_best_value = cntk.reduce_sum(target_model * 
                                        one_hot(online_post_state_best_action, num_classes=action_shape))

gamma = 0.99
target = reward_var + (1.0 - terminal_var) * gamma * post_state_best_value

# MSE for simplicity
td_error = state_value - cntk.stop_gradient(target)
loss = cntk.reduce_mean(cntk.square(td_error))

如果我更换

online_post_state_model = model_fn(post_state_var)

online_post_state_model = model_fn.clone(cntk.CloneMethod.freeze)(post_state_var)

然后错误消失了,但这是错误的,因为它使用旧的冻结模型来计算目标。如何使用model_fn评估post_state_var并排除反向传播的输出?我没有正确使用stop_gradient吗?

1 个答案:

答案 0 :(得分:0)

one_hot的典型用法是输入数据,您通常不需要反向传播。

解决方法是将操作作为图表中的一个热矢量。您可以使用hardmax代替argmax来完成此操作。