加载预训练的 pytorch 模型

时间:2020-12-29 06:15:51

标签: python pytorch torch

虽然问题看起来很简单,但在别处无法找到此解决方案。

我有一个 pytorch (.pt) 文件,我正在尝试加载它。我知道我需要先通过做来构建模型

model = MyModel()

但是我的 pytorch 文件构建了一个模型 (se_resnext101_32x4d),我没有为其创建一个类。因此,当我尝试做

model = se_resnext101_32x4d()

出现错误

name 'se_resnext101_32x4d' is not defined

我尝试过

import pretrainedmodels

model = pretrainedmodels.__dict__[se_resnext101_32x4d]()

但错误仍然存​​在。

1 个答案:

答案 0 :(得分:0)

稍加搜索后,您似乎正在尝试使用包含预训练模型和 API 的 this package 来下载和使用它们。根据 their documentation,您可以像这样加载模型:

import pretrainedmodels

Model = pretrainedmodels.__dict__['se_resnext101_32x4d']
model = Model(num_classes=1000, pretrained='imagenet')
model.eval()

如果您还没有安装 pip 包,请不要忘记安装:

pip install pretrainedmodels