了解何时在Pytorch中使用python列表

时间:2019-03-15 18:24:40

标签: deep-learning pytorch backpropagation

基本上,正如该线程讨论的here一样,您不能使用python list来包装子模块(例如,您的图层);否则,Pytorch不会更新列表中子模块的参数。相反,您应该使用nn.ModuleList来包装子模块,以确保其参数将被更新。现在,我还看到了类似以下代码的代码,其中作者使用python列表计算损失,然后执行loss.backward()进行更新(强化RL算法)。这是代码:

 policy_loss = []
    for log_prob in self.controller.log_probability_slected_action_list:
        policy_loss.append(- log_prob * (average_reward - b))
    self.optimizer.zero_grad()
    final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
    final_policy_loss.backward()
    self.optimizer.step()

为什么使用这种格式的列表可以更新模块的参数,但第一种情况不起作用?我现在很困惑。如果我更改了先前的代码policy_loss = nn.ModuleList([]),它将引发异常,表明张量浮点不是子模块。

1 个答案:

答案 0 :(得分:2)

您误解了Module是什么。 Module存储参数并定义正向传递的实现。

允许您使用张量和参数执行任意计算,从而生成其他新张量。 Modules不需要知道那些张量。您还可以将张量列表存储在Python列表中。调用backward时,它必须在标量张量上,因此必须是串联的总和。这些张量是损失,而不是参数,因此它们不应是Module的属性,也不应包装在ModuleList中。