有什么方法可以将PyTorch中可用的预训练模型下载到特定路径?

时间:2018-10-03 13:35:08

标签: python pytorch torchvision

3 个答案:

答案 0 :(得分:5)

在加载预先训练的模型时,内部会调用@dennlinger中的answertorch.utils.model_zoo

更具体地说,每次加载预训练的模型时都会调用方法torch.utils.model_zoo.load_url()。相同的文档中提到:

  

model_dir的默认值为$TORCH_HOME/models,其中   $TORCH_HOME默认为~/.torch

     

默认目录可以用$TORCH_MODEL_ZOO覆盖   环境变量。

这可以如下进行:

import torch 
import torchvision
import os

# Suppose you are trying to load pre-trained resnet model in directory- models\resnet

os.environ['TORCH_MODEL_ZOO'] = 'models\\resnet' #setting the environment variable
resnet = torchvision.models.resnet18(pretrained=True)

我在PyTorch的GitHub存储库中提出了一个问题,从而遇到了上述解决方案: https://github.com/pytorch/vision/issues/616

这导致了文档的改进,即上述解决方案。

答案 1 :(得分:4)

是的,您只需复制URL并使用wget即可将其下载到所需的路径。这是一个例子:

对于 AlexNet

$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth

对于 Google Inception(v3)

$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth

对于 SqueezeNet

$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth

如果要使用Python进行操作,请使用类似以下内容的

In [11]: from six.moves import urllib

# resnet 101 host url
In [12]: url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"

# download and rename the file to `resnet_101.pth`
In [13]: urllib.request.urlretrieve(url, "resnet_101.pth")
Out[13]: ('resnet_101.pth', <http.client.HTTPMessage at 0x7f7fd7f53438>)

P.S:您可以在torchvision.models

的单个python文件中找到所有下载URL。

答案 2 :(得分:1)

TL; DR:不,这不可能直接实现,但是您可以轻松地对其进行调整。

我认为您想要做的是看看torch.utils.model_zoo,当您加载预先训练的模型时会在内部调用它:

如果我们查看经过预训练的模型的代码,例如AlexNet here,我们会发现它只是调用了前面提到的model_zoo函数,但没有保存位置。您可以修改PyTorch源以指定它(这实际上是IMO的一个很好的补充,因此可以为此打开一个拉取请求),也可以根据自己的喜好采用第二个链接中的代码(并将其保存到自定义位置(使用其他名称),然后在此处手动插入相关位置。

如果您想定期更新PyTorch,我会强烈推荐第二种方法,因为它不涉及直接更改PyTorch的代码库,并且可能在更新过程中引发错误。