我需要将火炬模型转换为pytorch。由于火炬模型具有pytorch不支持的层(例如初始和LRN),因此无法使用内置API。为了将这些模型从火炬转换为pytorch,有必要在pytorch中实现这些层,并将火炬模型中的所有参数保存为hdf5文件,并将它们作为字典重新加载到python。我是lua的新手,我想问一下如何获取这个昵称'火炬中的所有参数。
顺便说一句,这可以在pytorch中轻松完成,例如:
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(in_channels=3,out_channels=32,kernel_size=7,stride=1,bias=False),
nn.ReLU(inplace=True),
nn.BatchNorm2d(num_features=32,affine=True),
nn.MaxPool2d(kernel_size=2,stride=2)
)
for key in model.state_dict():
value = model.state_dict().get(key)
print(key, value.size())
如果所有参数都可以以字典格式访问,则可以使用以下代码在pytorch中重建模型:
model = MyNewInceptionModel()
model.load_state_dict(param_dict)