基本上,正如该线程讨论的here一样,您不能使用python list来包装子模块(例如,您的图层);否则,Pytorch不会更新列表中子模块的参数。相反,您应该使用nn.ModuleList
来包装子模块,以确保其参数将被更新。现在,我还看到了类似以下代码的代码,其中作者使用python列表计算损失,然后执行loss.backward()
进行更新(强化RL算法)。这是代码:
policy_loss = []
for log_prob in self.controller.log_probability_slected_action_list:
policy_loss.append(- log_prob * (average_reward - b))
self.optimizer.zero_grad()
final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
final_policy_loss.backward()
self.optimizer.step()
为什么使用这种格式的列表可以更新模块的参数,但第一种情况不起作用?我现在很困惑。如果我更改了先前的代码policy_loss = nn.ModuleList([])
,它将引发异常,表明张量浮点不是子模块。
答案 0 :(得分:2)
您误解了Module
是什么。 Module
存储参数并定义正向传递的实现。
允许您使用张量和参数执行任意计算,从而生成其他新张量。 Modules
不需要知道那些张量。您还可以将张量列表存储在Python列表中。调用backward
时,它必须在标量张量上,因此必须是串联的总和。这些张量是损失,而不是参数,因此它们不应是Module
的属性,也不应包装在ModuleList
中。