在火炬中访问参数名称

时间:2017-09-08 04:56:15

标签: torch pytorch

我需要将火炬模型转换为pytorch。由于火炬模型具有pytorch不支持的层(例如初始和LRN),因此无法使用内置API。为了将这些模型从火炬转换为pytorch,有必要在pytorch中实现这些层,并将火炬模型中的所有参数保存为hdf5文件,并将它们作为字典重新加载到python。我是lua的新手,我想问一下如何获取这个昵称'火炬中的所有参数。

顺便说一句,这可以在pytorch中轻松完成,例如:

import torch.nn as nn
model = nn.Sequential(
                nn.Conv2d(in_channels=3,out_channels=32,kernel_size=7,stride=1,bias=False),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(num_features=32,affine=True),
                nn.MaxPool2d(kernel_size=2,stride=2)
                )
for key in model.state_dict():
    value = model.state_dict().get(key)
    print(key, value.size())

如果所有参数都可以以字典格式访问,则可以使用以下代码在pytorch中重建模型:

model = MyNewInceptionModel()
model.load_state_dict(param_dict)

0 个答案:

没有答案