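The input image size is 512 * 512, to fit the ResNet input. I load the input image with
_img = Image.open(self.images[index]).convert('RGB')
in the data loader. I use resnet50 without the fc layer as the network backbone. The output shape is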
[4,2048,16,16]
I then apply two (conv, bn, relu) blocks and an interpolation:
def forward(self, input):
    x = self.backbone(input)
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = F.interpolate(x, size=[512, 512], mode='bilinear', align_corners=True)
    return x
Part of the training code:
self.criterion = nn.CrossEntropyLoss()
if self.args.cuda:
    image, target = image.cuda(), target.cuda()
self.scheduler(self.optimizer, i, epoch, self.best_pred)
self.optimizer.zero_grad()
output = self.model(image)
loss = self.criterion(output, target.long())
loss.backward()
But an error occurs:
File "E:/python_workspace/1006/train.py", line 135, in training
loss = self.criterion(output, target.long())
File "E:\python_workspace\1006\utils\loss.py", line 28, in CrossEntropyLoss
loss = criterion(logit, target.long())
File "E:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "E:\anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 916, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1995, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1826, in nll_loss
ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at C:\w\1\s\tmp_conda_3.6_045031\conda\conda-bld\pytorch_1565412750030\work\aten\src\THNN/generic/SpatialClassNLLCriterion.c:111
image.shape is [4, 3, 512, 512], dtype is torch.float32
target.shape is [4, 512, 512], dtype is torch.float32
output.shape is [4, 3, 512, 512], dtype is torch.float32
The target image has only three different colors, so I set the output to 3 channels, and the image mode is P. Where might my code be wrong?
Answer 0 (score: 1)
Judging by the sizes of your tensors, your batch_size=4 and you are trying to predict one of three labels per pixel, i.e. n_classes=3.
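As a rough illustration of the shapes involved (a minimal sketch using made-up random tensors and a smaller 8 x 8 spatial size, not your data): nn.CrossEntropyLoss takes logits of shape [N, C, H, W] and an integer target of shape [N, H, W] whose entries are class indices in the range 0 to C-1.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 8 x 8 instead of 512 x 512 to keep the example small.
batch_size, n_classes, height, width = 4, 3, 8, 8

# Stand-in for the model output: one score per class per pixel.
logits = torch.randn(batch_size, n_classes, height, width)

# Stand-in for the ground truth: one class index per pixel, dtype int64.
target = torch.randint(0, n_classes, (batch_size, height, width))

loss = nn.CrossEntropyLoss()(logits, target)
print(loss.item())
```

The loss is a single scalar here because every target value is a valid class index.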
The error you got:
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.
means that the target.long() you pass to the loss function contains values that are either negative or greater than or equal to n_classes.
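One way to confirm this before calling the loss is to list the label values that fall outside the valid range. The sketch below is only an illustration: check_targets is a hypothetical helper, the tiny target tensor is made up, and the default ignore_index of -100 is the one nn.CrossEntropyLoss uses.

```python
import torch

def check_targets(target, n_classes, ignore_index=-100):
    """Return the label values in `target` that lie outside [0, n_classes)."""
    values = torch.unique(target.long())  # sorted unique label values
    return [
        int(v) for v in values
        if int(v) != ignore_index and not (0 <= int(v) < n_classes)
    ]

# Made-up float target, as in the question, with one stray value of 255.
target = torch.tensor([[0, 1, 2], [2, 255, 1]], dtype=torch.float32)

bad_values = check_targets(target, n_classes=3)
print(bad_values)
```

If the returned list is not empty, the labels need to be remapped (or read differently) before they reach the loss.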
Check how you read the ground-truth labels. If it is a P-mode image, you need to read it as is, rather than converting it to RGB values.
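The difference can be seen on a tiny palette image built in memory. This is a sketch under assumptions: the 2 x 2 mask and its three-color palette are invented for illustration, and a real label file would be opened with Image.open instead.

```python
import numpy as np
from PIL import Image

# Build a 2 x 2 "P" mode mask as a stand-in for a real label file.
mask = Image.new("P", (2, 2))
mask.putdata([0, 1, 2, 1])  # palette indices, row by row

# Three made-up colors (black, dark red, dark green), padded to 256 entries.
palette = [0, 0, 0, 128, 0, 0, 0, 128, 0] + [0] * (768 - 9)
mask.putpalette(palette)

as_labels = np.array(mask)              # palette indices, shape (2, 2)
as_rgb = np.array(mask.convert("RGB"))  # color values, shape (2, 2, 3)

print(as_labels.max(), as_rgb.max())
```

Read as is, the array holds the class indices 0, 1 and 2. After convert('RGB') it holds color values such as 128, which are not valid class indices for a three-class loss.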
PS, do not use align_corners=True in F.interpolate; it can cause distortion.
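To see how the two settings differ, here is a small 1-D sketch (a made-up two-element input upsampled to four elements, using mode='linear' rather than the 'bilinear' mode from the question, purely to keep the numbers easy to read):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[0.0, 1.0]]])  # shape [N=1, C=1, L=2]

# Same input, same output size, only align_corners differs.
aligned = F.interpolate(x, size=4, mode="linear", align_corners=True)
default = F.interpolate(x, size=4, mode="linear", align_corners=False)

print(aligned.flatten().tolist())
print(default.flatten().tolist())
```

With align_corners=True the outputs are spaced evenly between the two end values (0, 1/3, 2/3, 1). With align_corners=False they are 0, 0.25, 0.75, 1, so the two settings place the upsampled values at different positions relative to the original pixels.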