如何获取未知PyTorch模型的输入张量形状

时间:2020-07-28 09:42:16

标签: python machine-learning deep-learning pytorch onnx

我正在编写一个python脚本,该脚本将任何流行的框架(TensorFlow,Keras,PyTorch)的深度学习模型转换为ONNX格式。目前,我已经将tf2onnx用于张量流,并将keras2onnx用于keras到ONNX的转换。

现在PyTorch集成了ONNX支持,因此我可以直接从PyTorch保存ONNX模型。但是问题是我需要为该模型输入张量形状,以便将其保存为ONNX格式。您可能已经猜到了,我正在编写此脚本来转换未知的深度学习模型。

Here是PyTorch的ONNX转换教程。上面写着:

限制¶ ONNX导出器是基于跟踪的导出器,这意味着它通过执行一次模型并导出在此运行期间实际运行的运算符进行操作。这意味着,如果您的模型是动态的(例如,根据输入数据更改行为),则导出将不准确。

类似地,跟踪可能仅对特定的输入大小才有效(这是我们在跟踪时需要显式输入的原因之一)。大多数操作员会导出与尺寸无关的版本,并且应在不同的批次尺寸或输入尺寸下工作。我们建议检查模型跟踪,并确保跟踪的运算符看起来合理。


我正在使用的代码段是这样:

dist
    css
        admin.css
        pages.css
    js
        admin-bundle.js
        pages-bundle.js

那么我怎么知道那个未知的PyTorch模型的输入张量的INPUT_SHAPE?还是有其他方法可以将PyTorch模型转换为ONNX?

1 个答案:

答案 0 :(得分:3)

您可以以此为起点进行调试

rm <path-to-your-env>/bin/python

然后得到N,C并通过将H,W专门设置为None来创建一个张量,就像这个玩具示例一样

cp /usr/bin/python <path-to-your-env>/bin/python