我在分类问题中使用vgg19。我可以使用校园研究计算机进行培训,但是完成计算的节点无法访问互联网。因此,运行诸如self.net = models.vgg19(pretrained=True)
之类的代码行会失败,并显示错误urllib.error.URLError: <urlopen error [Errno 101] Network is unreachable>
有没有办法将模型缓存在头节点(可以访问互联网的地方)上,并从缓存而不是计算节点上的Internet加载模型?
答案 0 :(得分:2)
如果仅将经过预训练的网络的权重保存在某个位置,则可以像加载任何其他网络权重一样加载它们。
保存:
import torchvision
# I am assuming we have internet access here
model = torchvision.models.vgg16(pretrained=True)
torch.save(model.state_dict(), "Somewhere")
正在加载:
import torchvision
def create_vgg16(dict_path=None):
model = torchvision.models.vgg16(pretrained=False)
if (dict_path != None):
model.load_state_dict(torch.load(dict_path))
return model
model = create_vgg16("Somewhere")