如何加载预训练的 pytorch 权重

时间:2021-05-17 08:59:02

标签: python pytorch onnx

我正在关注此博客 https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html 想在 onnx 运行时运行 pytorch 模型。在示例中,它给了一个预训练权重一个 URL 如何从本地磁盘加载一个预训练权重。

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
path = "/content/best.pt"
batch_size = 1    # just a random number

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# set the model to inference mode
torch_model.eval()

我想加载定义为 Path 的权重。

1 个答案:

答案 0 :(得分:0)

如果你想从路径加载状态字典,你应该这样做:

torch_model.load_state_dict(torch.load(path))

这应该有效。