我已将我的PyTorch模型导出到ONNX。现在,我有办法从该ONNX模型获取输入层吗?
将PyTorch模型导出到ONNX
import torch.onnx
checkpoint = torch.load("./saved_pytorch_model.pth")
model.load_state_dict(checkpoint['state_dict'])
input = torch.tensor(df_X.values).float()
torch.onnx.export(model, input, "onnx_model.onnx")
加载ONNX模型
onnx_model = onnx.load('onnx_model.onnx')
我希望能够以某种方式从onnx_model获取输入层。这可能吗?
答案 0 :(得分:0)
ONNX模型是一个protobuf结构,如此处(https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto)所定义。您可以使用为python生成的标准protobuf方法来使用它(请参见:https://developers.google.com/protocol-buffers/docs/reference/python-generated)。我不明白您要提取什么。但是您可以遍历组成图的节点( model.graph.node )。图中的第一个节点可能与您可能认为的第一层相对应(取决于翻译的完成方式),也可能不对应。您还可以获取模型的输入( model.graph.input )。