ret中的错误=火炬._C._nn.nll_loss2d(输入,目标,重量,_Reduction.get_enum(reduction),ignore_index)

时间:2019-10-06 14:31:53

标签: pytorch

输入图像大小为512 * 512,以适合resnet的输入。input image 我用

_img = Image.open(self.images[index]).convert('RGB')

在数据加载器中。 我将resnet50用作没有fc的网络骨干。输出形状为

[4,2048,16,16]

然后使用了两个(conv bn relu)和一个插值

    def forward(self, input):
        x=self.backbone(input)
        x = self.conv1(x)
        x= self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x= self.bn2(x)
        x = self.relu(x)
        x = F.interpolate(x, size=[512,512], mode='bilinear', align_corners=True)
        return x

训练的一部分

    self.criterion=nn.CrossEntropyLoss()
    if self.args.cuda:
        image, target = image.cuda(), target.cuda()
    self.scheduler(self.optimizer, i, epoch, self.best_pred)
    self.optimizer.zero_grad()
    output = self.model(image)
    loss = self.criterion(output, target.long())
    loss.backward()

但是发生错误

File "E:/python_workspace/1006/train.py", line 135, in training
loss = self.criterion(output, target.long())
File "E:\python_workspace\1006\utils\loss.py", line 28, in CrossEntropyLoss
loss = criterion(logit, target.long())
File "E:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "E:\anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 916, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1995, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1826, in nll_loss
ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at C:\w\1\s\tmp_conda_3.6_045031\conda\conda-bld\pytorch_1565412750030\work\aten\src\THNN/generic/SpatialClassNLLCriterion.c:111

image.shape is [4, 3, 512, 512],dtype is torch.float32
target.shape is [4, 512, 512],dtype is torch.float32
output.shape is [4, 3, 512, 512],dtype is torch.float32

target image 目标图像只有三种不同的颜色,所以我将输出设置为3通道,并且图像模式为P 我的代码哪里可能有问题?

1 个答案:

答案 0 :(得分:1)

根据您的batch_size=4的大小来判断。您正在尝试预测每个像素三个标签之一,即n_classes=3

您收到的错误:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.

意味着您提供给损失函数的target.long()的值可以为负,或者大于n_classes

检查阅读地面真相标签的方式。如果它是P类型的图像,则需要照原样读取,而不是将其转换为RGB值。

PS,
不要align_corners=True中使用F.interpolate,这会导致失真。