How can I extract features from a specific layer of a pretrained PyTorch model (e.g. ResNet or VGG) without running the forward pass again?
Answer 0 (score: 1)
You can register a forward hook on the specific layer you want. Something like:
def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
model(some_input)
For example, to get the res5c output in ResNet, you may want to use a nonlocal variable (or global in Python 2):
res5c_output = None

def res5c_hook(module, input_, output):
    nonlocal res5c_output
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)
resnet(some_input)
# Then, use `res5c_output`.
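One caveat worth adding (my note, not part of the original answer): nonlocal is only legal inside a nested function, and register_forward_hook returns a handle you can remove once you are done. A minimal self-contained sketch, where the wrapper name extract_res5c is my own:

import torch
import torchvision.models as models

def extract_res5c(resnet, some_input):
    res5c_output = None

    def res5c_hook(module, input_, output):
        nonlocal res5c_output  # valid here thanks to the enclosing function
        res5c_output = output

    handle = resnet.layer4.register_forward_hook(res5c_hook)
    with torch.no_grad():
        resnet(some_input)
    handle.remove()  # detach the hook so later forward passes are unaffected
    return res5c_output

# Example usage with a dummy 224x224 RGB batch:
resnet = models.resnet18(pretrained=True).eval()
res5c_output = extract_res5c(resnet, torch.randn(1, 3, 224, 224))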
Answer 1 (score: 1)
The accepted answer was very helpful! I'm posting a complete example here (using the hook registration described by @bryant1410) for anyone looking for a ready-to-use working solution:
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

def get_feat_vector(path_img, model):
    '''
    Input:
        path_img: string, /path/to/image
        model: a pretrained torch model
    Output:
        my_output: torch.tensor, output of avgpool layer
    '''
    input_image = Image.open(path_img)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        my_output = None

        def my_hook(module_, input_, output_):
            nonlocal my_output
            my_output = output_

        a_hook = model.avgpool.register_forward_hook(my_hook)
        model(input_batch)
        a_hook.remove()
    return my_output
There you have your feature extraction function. Just call it with the snippet below to get the features from the resnet18.avgpool layer:
model = models.resnet18(pretrained=True)
model.eval()
path_ = '/path/to/image'
my_feature = get_feat_vector(path_, model)
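As a usage note (my addition, assuming the standard torchvision resnet18): the avgpool output has shape [1, 512, 1, 1], so you will usually want to flatten it into a feature vector:

feat = my_feature.flatten()                            # shape: torch.Size([512])
feat_batched = torch.flatten(my_feature, start_dim=1)  # shape: torch.Size([1, 512]), keeps the batch dimension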
Answer 2 (score: 0)
I am trying to extract the features of a specific layer of a pretrained model. The code below, based on bryant1410's answer, does work, but the value of template_feature_map changes even though I don't do anything to it.
The output of the model's sixth layer (vgg_feature[5]) should contain negative values, as the first print(template_feature_map) shows. However, those negative values, which should still be present at the second print(template_feature_map), have been changed to zeros, and I don't know why. If you know the mechanism behind this, please tell me how to keep the negative values.
vgg_feature = models.vgg13(pretrained=True).features
template_feature_map = None

def save_template_feature_map(module, input, output):
    global template_feature_map
    template_feature_map = output
    print(template_feature_map)

template_handle = vgg_feature[5].register_forward_hook(save_template_feature_map)
vgg_feature(template[0])
print(template_feature_map)
Output of the two print(template_feature_map) calls:
tensor([[[[-5.7389e-01, -2.7154e+00, -4.0990e+00, ..., 4.1902e+00,
3.1757e+00, 2.2461e+00],
[-2.2217e+00, -4.3395e+00, -6.8158e+00, ..., -1.4454e+00,
9.8012e-01, -2.3653e+00],
[-4.1940e+00, -6.3235e+00, -6.8422e+00, ..., -2.8329e+00,
2.5570e+00, -2.7704e+00],
...,
[-3.3250e+00, 1.3792e-01, 5.4926e+00, ..., -4.1722e+00,
-6.1008e-01, -2.6037e+00],
[ 1.5377e+00, 6.0671e-01, 2.0974e+00, ..., 1.2441e+00,
1.5033e+00, -2.7246e+00],
[ 6.8857e-01, -3.5160e-02, 6.7858e-01, ..., 1.2052e+00,
1.4533e+00, -1.4160e+00]],
[[ 6.8798e-01, 1.6971e+00, 2.1629e+00, ..., 3.1701e-01,
8.5424e-01, 2.8768e+00],
[ 1.4013e+00, 2.7217e+00, 2.1476e+00, ..., 3.1156e+00,
4.4858e+00, 3.6936e+00],
[ 3.1807e+00, 2.2245e+00, 2.4665e+00, ..., 1.3838e+00,
1.0580e-02, -3.1445e-03],
...,
[-4.7298e+00, -3.3037e+00, -1.2982e+00, ..., 2.3266e-01,
6.7711e+00, 3.8166e+00],
[-4.7972e+00, -5.4591e+00, -2.5201e+00, ..., 3.7584e+00,
5.1524e+00, 2.3072e+00],
[-2.4306e+00, -2.8033e+00, -2.0912e+00, ..., 1.9888e+00,
2.0582e+00, 1.9266e+00]],
[[-4.4257e+00, -4.6331e+00, -3.3580e-03, ..., -8.2233e+00,
-7.4645e+00, -1.7361e+00],
[-4.5593e+00, -8.4195e+00, -8.8428e+00, ..., -6.7950e+00,
-1.4665e+01, -2.5335e+00],
[-2.3481e+00, -3.8543e+00, -3.5965e+00, ..., -1.5105e+00,
-1.6923e+01, -5.9852e+00],
...,
[-8.0165e+00, 8.0185e+00, 6.5506e+00, ..., 5.3241e+00,
3.3854e+00, -1.6342e+00],
[-1.3689e+01, -2.2930e+00, 4.7097e+00, ..., 3.2021e+00,
2.9208e+00, -8.0228e-01],
[-1.3055e+01, -1.1470e+01, -8.4442e+00, ..., 1.8155e-02,
-6.2866e-02, -2.0333e+00]],
...,
[[ 3.4622e+00, -1.2417e+00, -5.0749e+00, ..., 5.3184e+00,
1.4744e+01, 8.3968e+00],
[-2.7820e+00, -9.1911e+00, -1.1069e+01, ..., 2.5380e+00,
9.8336e+00, 4.0623e+00],
[-3.9794e+00, -1.0140e+01, -9.9133e+00, ..., 3.0999e+00,
5.5936e+00, 2.5775e+00],
...,
[ 2.0299e+00, 2.1304e-01, -2.2307e+00, ..., 1.1388e+01,
8.8098e+00, 1.8991e+00],
[ 8.0663e-01, -1.5073e+00, 3.3977e-01, ..., 8.5316e+00,
4.9923e+00, -3.6818e-01],
[-3.5146e+00, -7.2647e+00, -5.4331e+00, ..., -1.9781e+00,
-3.4463e+00, -4.9034e+00]],
[[-3.2915e+00, -7.3263e+00, -6.8458e+00, ..., 2.3122e+00,
9.7774e-01, -1.3498e+00],
[-4.5396e+00, -8.6832e+00, -8.8582e+00, ..., 7.1535e-02,
-4.1133e+00, -4.4045e+00],
[-4.8781e+00, -7.0239e+00, -4.7350e+00, ..., -3.6954e+00,
-9.6687e+00, -8.8289e+00],
...,
[-4.7072e+00, -4.4823e-01, 1.7099e+00, ..., 3.7923e+00,
1.6887e+00, -4.3305e+00],
[-5.5120e+00, -3.2324e+00, 2.3594e+00, ..., 4.6031e+00,
1.8856e+00, -4.0147e+00],
[-5.1355e+00, -5.5335e+00, -1.7738e+00, ..., 1.6159e+00,
-1.3950e+00, -4.1055e+00]],
[[-2.0252e+00, -2.3971e+00, -1.6477e+00, ..., -3.3740e+00,
-4.9965e+00, -2.1219e+00],
[-7.6059e-01, -3.3901e-01, -1.8980e-01, ..., -4.3286e+00,
-7.1350e+00, -3.9186e+00],
[ 8.4101e-01, 1.3403e+00, 2.5821e-01, ..., -5.1847e+00,
-7.1829e+00, -3.7724e+00],
...,
[-6.0619e+00, -5.6475e+00, -1.6446e+00, ..., -9.2322e+00,
-9.1981e+00, -5.5239e+00],
[-7.4606e+00, -7.6054e+00, -5.8401e+00, ..., -7.6998e+00,
-6.4111e+00, -2.9374e+00],
[-6.4147e+00, -7.2813e+00, -6.1880e+00, ..., -4.6726e+00,
-3.1090e+00, -7.8383e-01]]]], grad_fn=<MkldnnConvolutionBackward>)
tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 4.1902e+00,
3.1757e+00, 2.2461e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
9.8012e-01, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
2.5570e+00, 0.0000e+00],
...,
[0.0000e+00, 1.3792e-01, 5.4926e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[1.5377e+00, 6.0671e-01, 2.0974e+00, ..., 1.2441e+00,
1.5033e+00, 0.0000e+00],
[6.8857e-01, 0.0000e+00, 6.7858e-01, ..., 1.2052e+00,
1.4533e+00, 0.0000e+00]],
[[6.8798e-01, 1.6971e+00, 2.1629e+00, ..., 3.1701e-01,
8.5424e-01, 2.8768e+00],
[1.4013e+00, 2.7217e+00, 2.1476e+00, ..., 3.1156e+00,
4.4858e+00, 3.6936e+00],
[3.1807e+00, 2.2245e+00, 2.4665e+00, ..., 1.3838e+00,
1.0580e-02, 0.0000e+00],
...,
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 2.3266e-01,
6.7711e+00, 3.8166e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.7584e+00,
5.1524e+00, 2.3072e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 1.9888e+00,
2.0582e+00, 1.9266e+00]],
[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
...,
[0.0000e+00, 8.0185e+00, 6.5506e+00, ..., 5.3241e+00,
3.3854e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 4.7097e+00, ..., 3.2021e+00,
2.9208e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 1.8155e-02,
0.0000e+00, 0.0000e+00]],
...,
[[3.4622e+00, 0.0000e+00, 0.0000e+00, ..., 5.3184e+00,
1.4744e+01, 8.3968e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 2.5380e+00,
9.8336e+00, 4.0623e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0999e+00,
5.5936e+00, 2.5775e+00],
...,
[2.0299e+00, 2.1304e-01, 0.0000e+00, ..., 1.1388e+01,
8.8098e+00, 1.8991e+00],
[8.0663e-01, 0.0000e+00, 3.3977e-01, ..., 8.5316e+00,
4.9923e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 2.3122e+00,
9.7774e-01, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 7.1535e-02,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
...,
[0.0000e+00, 0.0000e+00, 1.7099e+00, ..., 3.7923e+00,
1.6887e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 2.3594e+00, ..., 4.6031e+00,
1.8856e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 1.6159e+00,
0.0000e+00, 0.0000e+00]],
[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[8.4101e-01, 1.3403e+00, 2.5821e-01, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
...,
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]]], grad_fn=<ThresholdBackward1>)
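A likely explanation (my own reading, not part of the original question): in torchvision's VGG, the conv layer at index 5 is immediately followed by nn.ReLU(inplace=True) at index 6, which overwrites the conv output tensor in place during the rest of the forward pass; the grad_fn changing from MkldnnConvolutionBackward to ThresholdBackward1 between the two prints is consistent with that. A minimal sketch that keeps the negative values by cloning the output inside the hook (the random input is only a stand-in for template[0]):

import torch
import torchvision.models as models

vgg_feature = models.vgg13(pretrained=True).features.eval()
template_feature_map = None

def save_template_feature_map(module, input, output):
    global template_feature_map
    # Clone (and detach) so the in-place ReLU that follows cannot overwrite the saved values.
    template_feature_map = output.detach().clone()

handle = vgg_feature[5].register_forward_hook(save_template_feature_map)
with torch.no_grad():
    vgg_feature(torch.randn(1, 3, 224, 224))  # stand-in for template[0]
handle.remove()

print(template_feature_map.min())  # negative values are preserved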