在使用torchvision.models
模块加载经过预训练的VGG网络并将其用于分类任意RGB图像时,网络的输出因调用而异。为什么会这样?据我了解,VGG前传的任何部分都不应该是不确定的。
这是MCVE:
import torch
from torchvision.models import vgg16
vgg = vgg16(pretrained=True)
img = torch.randn(1, 3, 256, 256)
torch.all(torch.eq(vgg(img), vgg(img))) # result is 0, but why?
答案 0 :(得分:1)
vgg16
有一个nn.Dropout
层,在训练期间,随机会丢弃其输入的50%。在测试期间,您应通过将网络模式设置为“评估”模式来“关闭”此行为:
vgg.eval()
torch.all(torch.eq(vgg(img), vgg(img)))
Out[73]: tensor(1, dtype=torch.uint8)
请注意,还有其他具有随机行为和不同行为的层可以进行训练和评估(例如BatchNorm)。因此,在评估经过训练的模型之前,必须切换到eval()
模式。