PyTorch RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

Time: 2019-10-22 03:53:50

Tags: python computer-vision pytorch

I used the code from here to train a model that predicts printed-style digits from 0 to 9:

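import matplotlib.pyplot as plt
import torch
from PIL import Image

# `model` and `image_transforms` are defined in the training code.
idx_to_class = {0: "0", 1: "1", 2: "2", 3: "3", 4: "4", 5: "5", 6: "6", 7: "7", 8: "8", 9: "9"}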
def predict(model, test_image_name):

    transform = image_transforms['test']

    test_image = Image.open(test_image_name)
    plt.imshow(test_image)

    test_image_tensor = transform(test_image)

    if torch.cuda.is_available():
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()
    else:
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224)

    with torch.no_grad():
        model.eval()
        # Model outputs log probabilities
        out = model(test_image_tensor)
        ps = torch.exp(out)
        topk, topclass = ps.topk(1, dim=1)
        # print(topclass.cpu().numpy()[0][0])
        print("Image class:  ", idx_to_class[topclass.cpu().numpy()[0][0]])

predict(model, "path_of_test_image")

But when I try to use predict, I get an error:

Traceback (most recent call last):

  File "<ipython-input-12-f8636d3ba083>", line 26, in <module>
    predict(model, "/home/x/文档/Deep_Learning/pytorch/MNIST/test/2/QQ截图20191022093955.png")

  File "<ipython-input-12-f8636d3ba083>", line 9, in predict
    test_image_tensor = transform(test_image)

  File "/home/x/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)

  File "/home/x/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 166, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)

  File "/home/x/.local/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 217, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

How can I fix this? Thanks.

1 Answer:

Answer 0: (score: 2)

I suspect your test_image has an additional alpha channel per pixel, so it has 4 channels instead of only 3.
Try:

test_image
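As a minimal sketch of the idea, assuming Pillow is what provides `Image` in the question, the snippet below shows how an RGBA image carries 4 channels and how `convert("RGB")` reduces it to 3. The helper name `load_rgb` and the in-memory PNG are hypothetical, used only for illustration; they are not part of the original code.

```python
import io

from PIL import Image


def load_rgb(path_or_file):
    """Open an image and return it in 3-channel RGB mode.

    Hypothetical helper: any alpha channel (RGBA) is dropped by the
    conversion, so a 3-value mean/std normalization can be applied later.
    """
    image = Image.open(path_or_file)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


# Build a small RGBA PNG in memory to stand in for a screenshot with alpha.
rgba_image = Image.new("RGBA", (8, 8), (255, 255, 255, 255))
buffer = io.BytesIO()
rgba_image.save(buffer, format="PNG")
buffer.seek(0)

# The stored image has 4 bands; after loading through the helper it has 3.
rgb_image = load_rgb(buffer)
print(rgba_image.mode, len(rgba_image.getbands()))
print(rgb_image.mode, len(rgb_image.getbands()))
```

In the question's `predict` function, this would correspond to loading the test image in RGB mode before passing it to `transform`, so that the tensor reaching the normalization step has 3 channels.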