未连接互联网时如何缓存Pytorch模型以供使用?

时间:2020-02-20 04:08:00

标签: python pytorch vgg-net

我在分类问题中使用vgg19。我可以使用校园研究计算机进行培训,但是完成计算的节点无法访问互联网。因此,运行诸如self.net = models.vgg19(pretrained=True)之类的代码行会失败,并显示错误urllib.error.URLError: <urlopen error [Errno 101] Network is unreachable>

有没有办法将模型缓存在头节点(可以访问互联网的地方)上,并从缓存而不是计算节点上的Internet加载模型?

1 个答案:

答案 0 :(得分:2)

如果仅将经过预训练的网络的权重保存在某个位置,则可以像加载任何其他网络权重一样加载它们。

保存:

import torchvision

#  I am assuming we have internet access here
model = torchvision.models.vgg16(pretrained=True)
torch.save(model.state_dict(), "Somewhere")

正在加载:

import torchvision

def create_vgg16(dict_path=None):
    model = torchvision.models.vgg16(pretrained=False)
    if (dict_path != None):
        model.load_state_dict(torch.load(dict_path))
    return model

model = create_vgg16("Somewhere")