这是计算火炬中两个不同NN的两个损耗的梯度的正确方法吗?

时间:2020-05-30 13:04:21

标签: python pytorch loss-function

我在pytorch中定义了一个NN,并创建了该网络的两个实例,分别为self.actor_critic_r1self.actor_critic_r2。我计算每个网的损失,即loss1loss2,然后将其汇总并按照以下方式计算等级,

loss_r1 = value_loss_r1 + action_loss_r1 - dist_entropy_r1 * args.entropy_coef
loss_r2 = value_loss_r2 + action_loss_r2 - dist_entropy_r2 * args.entropy_coef
self.optimizer_r1.zero_grad()
self.optimizer_r2.zero_grad()
loss = loss_r1 + loss_r2
loss.backward()
self.optimizer_r1.step()
self.optimizer_r2.step()
clip_grad_norm_(self.actor_critic_r1.parameters(), args.max_grad_norm)
clip_grad_norm_(self.actor_critic_r2.parameters(), args.max_grad_norm)

或者,我应该这样单独更新损失

self.optimizer_r1.zero_grad()
(value_loss_r1 + action_loss_r1 - dist_entropy_r1 * args.entropy_coef).backward()
self.optimizer_r1.step()
clip_grad_norm_(self.actor_critic_r1.parameters(), args.max_grad_norm)
self.optimizer_r2.zero_grad()
(value_loss_r2 + action_loss_r2 - dist_entropy_r2 * args.entropy_coef).backward()
self.optimizer_r2.step()
clip_grad_norm_(self.actor_critic_r2.parameters(), args.max_grad_norm)

我不确定这种正确的方法来更新具有多个损失的网络,请提供您的建议。

1 个答案:

答案 0 :(得分:1)

应该是求和方法。如果没有相互作用,那么对于“错误”优化器而言,“错误”损失的梯度将为零,并且如果存在 相互作用,则您可能希望针对该相互作用进行优化。

仅当您知道存在相互作用时,才可以使用方法2进行优化。