在进行Udacity的AI课程时,我遇到了标题错误。
一开始看起来好像没有为模型格式化图像,但是本课程的说明中我们使用了调整大小来格式化图像,所以我得到的图像尺寸为3x224x224
,即3个颜色通道, 224 x 224像素。图像阵列的尺寸为[3,224,224]
。
Return/Output:
(3, 224, 224)
(3, 224, 224)
torch.Size([1, 3, 224, 224])
接下来,我发现它是模型,但是当我返回时,在训练循环之前并更改原始模型以修复此错误,我在反转“上游”时遇到了类似的错误。
我希望此predict
函数使用我的模型来预测图像的类别。
发布的是:
classifier = nn.Sequential(OrderedDict([
('fc1', nn.Linear(25088, 500)),
('relu', nn.ReLU()),
('drop', nn.Dropout(0.25)),
('fc2', nn.Linear(500, 102)),
('output', nn.LogSoftmax(dim=1))
]))
model.load_state_dict(state_dict)
checkpoint= {'input_size': 672,
'output_size': 102,
'hidden_layers': 224,
'state_dict': model.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')
def predict(image_path, model, topk=5):
''' Predict the class (or classes) of an image using a trained deep learning model.
'''
model= model.cuda()
model.eval()
# TODO: Implement the code to predict the class from an image file
pil_img= Image.open(image_path)
#print(pil_img)
processed_image= process_image(pil_img)
print(processed_image.shape)
torch_image= torch.from_numpy(processed_image)
torch_image= torch_image.unsqueeze_(0)
print(torch_image.shape)
torch_image= torch_image.float().to('cuda')
output= model.forward(torch_image)
#top_k= predict.topk(topk)
#print(top_k)
return torch_img
image_path = 'flowers/test/8/image_03299.jpg'
predict(image_path, model, 5)
> RuntimeError Traceback (most recent call last)
<ipython-input-21-d84c5999ad68> in <module>()
23
24 image_path = 'flowers/test/8/image_03299.jpg'
---> 25 predict(image_path, model, 5)
<ipython-input-21-d84c5999ad68> in predict(image_path, model, topk)
16 print(torch_image.shape)
17 torch_image= torch_image.float().to('cuda')
---> 18 output= model.forward(torch_image)
19 #top_k= predict.topk(topk)
20 #print(top_k)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
89 def forward(self, input):
90 for module in self._modules.values():
---> 91 input = module(input)
92 return input
93
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
489 result = self._slow_forward(*input, **kwargs)
490 else:
--> 491 result = self.forward(*input, **kwargs)
492 for hook in self._forward_hooks.values():
493 hook_result = hook(self, input, result)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
53
54 def forward(self, input):
---> 55 return F.linear(input, self.weight, self.bias)
56
57 def extra_repr(self):
/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
992 return torch.addmm(bias, input, weight.t())
993
--> 994 output = input.matmul(weight.t())
995 if bias is not None:
996 output += bias
RuntimeError Traceback (most recent call last)
<ipython-input-21-d84c5999ad68> in <module>()
23
24 image_path = 'flowers/test/8/image_03299.jpg'
---> 25 predict(image_path, model, 5)
<ipython-input-21-d84c5999ad68> in predict(image_path, model, topk)
16 print(torch_image.shape)
17 torch_image= torch_image.float().to('cuda')
---> 18 output= model.forward(torch_image)
19 #top_k= predict.topk(topk)
20 #print(top_k)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
89 def forward(self, input):
90 for module in self._modules.values():
---> 91 input = module(input)
92 return input
93
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
489 result = self._slow_forward(*input, **kwargs)
490 else:
--> 491 result = self.forward(*input, **kwargs)
492 for hook in self._forward_hooks.values():
493 hook_result = hook(self, input, result)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
53
54 def forward(self, input):
---> 55 return F.linear(input, self.weight, self.bias)
56
57 def extra_repr(self):
/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
992 return torch.addmm(bias, input, weight.t())
993
--> 994 output = input.matmul(weight.t())
995 if bias is not None:
996 output += bias
RuntimeError: size mismatch, m1: [672 x 224], m2: [672 x 224] at /opt/conda/conda-bld/pytorch_1524584710464/work/aten/src/THC/generic/THCTensorMathBlas.cu:249