从Pytorch到ONNX导出功能失败,并导致旧版功能错误

时间:2019-10-22 17:20:58

标签: pytorch onnx

我正在尝试使用以下代码将this链接中的pytorch模型转换为onnx模型:

device=t.device('cuda:0' if t.cuda.is_available() else 'cpu')
print(device)

faster_rcnn = FasterRCNNVGG16()
trainer = FasterRCNNTrainer(faster_rcnn).cuda()
#trainer = FasterRCNNTrainer(faster_rcnn).to(device)
trainer.load('./checkpoints/model.pth')

dummy_input = t.randn(1, 3, 300, 300, device = 'cuda')
#dummy_input = dummy_input.to(device)
t.onnx.export(faster_rcnn, dummy_input, "model.onnx", verbose = True)

但是我收到以下错误消息(很抱歉,在stackoverflow下面的块引用不会让整个跟踪成为代码格式,也不会让问题被发布):

  Traceback (most recent call last):
     small_object_detection_master_samirsen\onnxtest.py", line 44, in <module>
       t.onnx.export(faster_rcnn, dummy_input, "fasterrcnn_10120119_06025842847785781.onnx", verbose = True)
     File "C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\onnx\__init__.py",
     

第132行,在导出中              strip_doc_string,dynamic_axes)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ onnx \ utils.py”,   第64行,在导出中              example_outputs = example_outputs,strip_doc_string = strip_doc_string,dynamic_axes = dynamic_axes)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ onnx \ utils.py”,   _export中的第329行              _retain_param_name,do_constant_folding)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ onnx \ utils.py”,   _model_to_graph中的第213行              图,火炬输出= _trace_and_get_graph_from_model(模型,参数,训练)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ onnx \ utils.py”,   _trace_and_get_graph_from_model中的第171行              跟踪,torch_out = torch.jit.get_trace_graph(模型,参数,_force_outplace = True)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ jit__init __。py”,   get_trace_graph中的第256行              return LegacyTracedModule(f,_force_outplace,return_inputs)(* args,** kwargs)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   第547行,在致电中              结果= self.forward(* input,** kwargs)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ jit__init __。py”,   323行,向前              out = self.inner(* trace_inputs)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   第545行,致电              结果= self._slow_forward(*输入,**扭曲)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   _slow_forward中的第531行           文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   _slow_forward中的第531行              结果= self.forward(* input,** kwargs)            文件“ D:\ smallobject2 \ export test s \ small_object_detection_master_samirsen \ model \ faster_rcnn.py”,行   133,向前              h,rois,roi_indices)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   第545行,致电              结果= self._slow_forward(*输入,**扭曲)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   _slow_forward中的第531行              结果= self.forward(* input,** kwargs)            文件“ D:\ smallobject2 \ export test s \ small_object_detection_master_samirsen \ model \ faster_rcnn_vgg16.py”,   142行,向前              池= self.roi(x,indexs_and_rois)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   第545行,致电              结果= self._slow_forward(*输入,**扭曲)            文件“ C:\ Users \ HP \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ torch \ nn \ modules \ module.py”,   _slow_forward中的第531行              结果= self.forward(* input,** kwargs)            文件“ D:\ smallobject2 \ export test s \ small_object_detection_master_samirsen \ model \ roi_module.py”,行   85,向前              返回self.RoI(x,rois)           RuntimeError:尝试跟踪RoI,但不支持跟踪旧功能

1 个答案:

答案 0 :(得分:-1)

这是因为 ONNX 不支持 torch.grad.Function。问题是因为 ROI 类 Refer this

要解决这个问题,您必须将前向和后向函数实现为单独的函数定义,而不是 ROI 类的成员。 FasterRCNNVGG16 中对 ROI 的函数调用应该更改为显式调用前向和后向函数。