是否有方便的方法来为pytorch模型转储running_stats?

时间:2019-09-22 14:12:52

标签: pytorch

我正在编写pytorch模型的C版本,以便在我的特殊硬件上运行它。 到目前为止,一切正常,除了每个batchnorm层中的running_mean和running_var。

我们有一个python代码来转储所有named_pa​​rameters,但对于running_stats则无济于事,尽管我们需要在转发计算中使用它。

那么有没有一种带有内置功能的转储方法? 我搜索了pytorch doc,对我的任务没有帮助。 否则,我可能需要编写一个regexp代码来识别并转储它们。

非常感谢。 /帕特里克

for name, param in model.named_parameters():
    # here can dump weight and bias, but not running_stats
    names.append(name)
    shapes.append(list(param.data.numpy().shape))
    values.append(param.data.numpy().flatten().tolist())

1 个答案:

答案 0 :(得分:1)

running_mean和其他registered_buffers在PyTorch中。您可以使用torch.nn.Module的{​​{3}}保存(如您所说的转储):

torch.save(model.state_dict(), PATH) 

您可以遍历命名缓冲区并保存每个缓冲区,但是您希望类似于参数:

for name, buffer in model.named_buffers():
    # do your thing with them