使用Dev Pytorch 1.0将Pytorch模型加载到C ++中

时间:2018-10-01 21:29:48

标签: pytorch

Pytorch 1.0具有将模型转换为Torch脚本程序(以某种方式序列化)的功能,以使其能够在C ++中执行而没有Python依赖项。

详细信息在本教程中。 https://pytorch.org/tutorials/advanced/cpp_export.html

这是完成的方式:

import torch
import torchvision

# An instance of your model.
model = A UNET MODEL FROM FASTAI which has hooks as required by UNET

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

在我的用例中,我正在使用UNET模型进行语义分割。但是,我使用这种方法跟踪模型,出现以下错误。

Forward or backward hooks can't be compiled 

UNET模型使用挂钩保存中间功能,这些功能在网络的后续层中使用。有办法解决吗?或这仍然是此新方法的局限性,它不能与使用此类挂钩的模型一起使用。

2 个答案:

答案 0 :(得分:0)

如果可以从Pytorch集线器使用UNET模型。它将与TorchScript一起使用。

import torch

# downloading the model from torchhub
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)

#  downloading the sample
import urllib
url, filename = ("https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png", "TCGA_CS_4944.png")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
    
# reading the sample and some prerequisites for transformation
import numpy as np
from PIL import Image
from torchvision import transforms

input_image = Image.open(filename)

m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))
preprocess = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=m, std=s),])

input_tensor = preprocess(input_image)

input_batch = input_tensor.unsqueeze(0)

# creating the trace
traced_module = torch.jit.trace(model,input_batch)

# running the trace
traced_module(input_batch)

PS:torch.jit.trace / torch.jit.script都不支持所有的Torch功能,因此将它们与外部库一起使用总是很棘手。

答案 1 :(得分:-1)

也许您可以用C ++重写模型,因为c ++ API与python版本几乎具有相同的接口。