如何将.pth模型转换为.pb文件?

时间:2019-12-23 04:45:41

标签: tensorflow pytorch

我已经使用pytorch获得了完整的模型,但是我想将.pth文件转换为.pb,可以在Tensorflow中使用。有人有想法吗?

1 个答案:

答案 0 :(得分:0)

您可以使用ONNX:开放式神经网络交换格式

要将.pth文件转换为.pb,首先,需要将PyTorch中定义的模型导出到ONNX,然后将ONNX模型导入Tensorflow(PyTorch => ONNX => Tensorflow)

这是从Convert a PyTorch model to Tensorflow using ONNXonnx/tutorials的MNISTModel的示例

将训练好的模型保存到文件中

torch.save(model.state_dict(), 'output/mnist.pth')

从文件中加载经过训练的模型

trained_model = Net()
trained_model.load_state_dict(torch.load('output/mnist.pth'))

# Export the trained model to ONNX
dummy_input = Variable(torch.randn(1, 1, 28, 28)) # one black and white 28 x 28 picture will be the input to the model
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")

加载ONNX文件

model = onnx.load('output/mnist.onnx')

# Import the ONNX model to Tensorflow
tf_rep = prepare(model)

将Tensorflow模型保存到文件中

tf_rep.export_graph('output/mnist.pb')