为了在pytorch中访问模型的参数,我看到了两种方法:
using state_dict
和using parameters()
我想知道两者之间有什么区别,或者一个是好的习惯,另一个是不好的习惯。
谢谢
答案 0 :(得分:4)
parameters()
仅提供模块参数,即权重和偏差。
返回模块参数上的迭代器。
您可以按以下方式检查参数列表:
for name, param in model.named_parameters():
if param.requires_grad:
print(name)
另一方面,state_dict
返回一个包含模块整个状态的字典。检查其source code
,其中不仅包含对parameters
的调用,还包含对buffers
等的调用。
包括参数和持久缓冲区(例如运行平均值)。键是相应的参数和缓冲区名称。
使用以下方法检查state_dict
包含的所有键:
model.state_dict().keys()
例如,在state_dict
中,您会找到bn1.running_mean
中不存在的诸如running_var
和.parameters()
之类的条目。
如果您只想访问参数,则可以简单地使用.parameters()
,而出于传输学习中保存和加载模型的目的,您不仅需要保存参数,还需要保存state_dict
。 / p>
答案 1 :(得分:2)
除了@kHarshit答案不同外,requires_grad
中可训练张量的属性net.parameters()
为True
,而False
中的net.state_dict()
>