将pytorch模型转换为core-ml时出错

时间:2019-05-20 08:56:22

标签: pytorch onnx onnx-coreml

C = torch.cat((A,B),1)

张量的形状:

A is (1, 128, 128, 256)
B is (1, 1, 128, 256)

预期的C值为(1, 129, 128, 256)

此代码可在pytorch上使用,但是在转换为core-ml时会出现以下错误:

"Error while converting op of type: {}. Error message: {}\n".format(node.op_type, err_message, )
TypeError: Error while converting op of type: Concat. Error message: unable to translate constant array shape to CoreML shape"

1 个答案:

答案 0 :(得分:0)

这是与coremltools版本有关的问题。尝试使用最新的beta coremltools 3.0b2。

以下操作在最新的Beta中没有任何错误。

import torch

class cat_model(torch.nn.Module):
    def __init__(self):
        super(cat_model, self).__init__()

    def forward(self, a, b):
        c = torch.cat((a, b), 1)
        # print(c.shape)
        return c

a = torch.randn((1, 128, 128, 256))
b = torch.randn((1, 1, 128, 256))

model = cat_model()
torch.onnx.export(model, (a, b), 'cat_model.onnx')

import onnx
model = onnx.load('cat_model.onnx')
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))

from onnx_coreml import convert
mlmodel = convert(model)