PyTorch:预先训练的VGG输出不一致

时间:2019-05-06 10:20:07

标签: python deep-learning computer-vision pytorch torchvision

在使用torchvision.models模块加载经过预训练的VGG网络并将其用于分类任意RGB图像时,网络的输出因调用而异。为什么会这样?据我了解,VGG前传的任何部分都不应该是不确定的。

这是MCVE:

import torch
from torchvision.models import vgg16

vgg = vgg16(pretrained=True)

img = torch.randn(1, 3, 256, 256)

torch.all(torch.eq(vgg(img), vgg(img))) # result is 0, but why?

1 个答案:

答案 0 :(得分:1)

vgg16有一个nn.Dropout层,在训练期间,随机会丢弃其输入的50%。在测试期间,您应通过将网络模式设置为“评估”模式来“关闭”此行为:

vgg.eval()
torch.all(torch.eq(vgg(img), vgg(img)))
Out[73]: tensor(1, dtype=torch.uint8)

请注意,还有其他具有随机行为和不同行为的层可以进行训练和评估(例如BatchNorm)。因此,在评估经过训练的模型之前,必须切换到eval()模式。