如果我希望它可以通过OpenCV dnn模块加载,我该如何保存PyTorch的模型

时间:2017-08-29 01:53:28

标签: python c++ opencv deep-learning pytorch

我通过PyTorch训练一个简单的分类模型并通过opencv3.3加载它,但它抛出异常然后说

  

OpenCV错误:readObject,file中未实现函数/功能(不支持的Lua类型)   /home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp,   797行   /home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp:797:   错误:(-213)函数readObject中不支持的Lua类型

模型定义

class conv_block(nn.Module):
    def __init__(self, in_filter, out_filter, kernel):
        super(conv_block, self).__init__()

        self.conv1 = nn.Conv2d(in_filter, out_filter, kernel, 1, (kernel - 1)//2)
        self.batchnorm = nn.BatchNorm2d(out_filter)
        self.maxpool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = F.relu(x)
        x = self.maxpool(x)

        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = conv_block(3, 6, 3)
        self.conv2 = conv_block(6, 16, 3)
        self.fc1 = nn.Linear(16 * 8 * 8, 120)
        self.bn1 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn2 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        return x

此模型仅使用Conv2d,ReLU,BatchNorm2d,MaxPool2d和线性图层,opencv3.3支持每个图层

我通过state_dict

保存它
torch.save(net.state_dict(), 'cifar10_model')

用c ++加载

std::string const model_file("/home/some_folder/cifar10_model");

std::cout<<"read net from torch"<<std::endl;
dnn::Net net = dnn::readNetFromTorch(model_file);

我想我用错误的方式保存模型,保存PyTorch模型以便使用OpenCV加载的正确方法是什么?感谢

编辑:

我使用另一种方法保存模型,但无法加载

torch.save(net, 'cifar10_model.net')

这是一个错误吗?或者我做错了什么?

1 个答案:

答案 0 :(得分:2)

我找到答案,opencv3.3不支持PyTorch(https://github.com/pytorch/pytorch)但是pytorch(https://github.com/hughperkins/pytorch),这是一个很大的惊喜,我不知道还有另一个版本的pytorch存在(看起来就像一个死的项目,很长一段时间没有更新),我希望他们可以提一下他们在维基上支持哪个pytorch。