我正在编写pytorch模型的C版本,以便在我的特殊硬件上运行它。 到目前为止,一切正常,除了每个batchnorm层中的running_mean和running_var。
我们有一个python代码来转储所有named_parameters,但对于running_stats则无济于事,尽管我们需要在转发计算中使用它。
那么有没有一种带有内置功能的转储方法? 我搜索了pytorch doc,对我的任务没有帮助。 否则,我可能需要编写一个regexp代码来识别并转储它们。
非常感谢。 /帕特里克
for name, param in model.named_parameters():
# here can dump weight and bias, but not running_stats
names.append(name)
shapes.append(list(param.data.numpy().shape))
values.append(param.data.numpy().flatten().tolist())
答案 0 :(得分:1)
running_mean
和其他registered_buffers
在PyTorch中。您可以使用torch.nn.Module
的{{3}}保存(如您所说的转储):
torch.save(model.state_dict(), PATH)
您可以遍历命名缓冲区并保存每个缓冲区,但是您希望类似于参数:
for name, buffer in model.named_buffers():
# do your thing with them