PyTorch:state_dict和parameters()有什么区别?

时间:2019-02-18 11:58:35

标签: python machine-learning deep-learning pytorch

为了在pytorch中访问模型的参数,我看到了两种方法:

using state_dictusing parameters()

我想知道两者之间有什么区别,或者一个是好的习惯,另一个是不好的习惯。

谢谢

2 个答案:

答案 0 :(得分:4)

parameters()仅提供模块参数,即权重和偏差。

  

返回模块参数上的迭代器。

您可以按以下方式检查参数列表:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

另一方面,state_dict返回一个包含模块整个状态的字典。检查其source code,其中不仅包含对parameters的调用,还包含对buffers等的调用。

  

包括参数和持久缓冲区(例如运行平均值)。键是相应的参数和缓冲区名称。

使用以下方法检查state_dict包含的所有键:

model.state_dict().keys()

例如,在state_dict中,您会找到bn1.running_mean中不存在的诸如running_var.parameters()之类的条目。


如果您只想访问参数,则可以简单地使用.parameters(),而出于传输学习中保存和加载模型的目的,您不仅需要保存参数,还需要保存state_dict。 / p>

答案 1 :(得分:2)

除了@kHarshit答案不同外,requires_grad中可训练张量的属性net.parameters()True,而False中的net.state_dict()