Dice loss in 3D PyTorch

Date: 2019-06-08 16:16:49

Tags: deep-learning pytorch

I am trying to integrate a Dice loss with my UNet model; the Dice loss is borrowed from another task. This is what it looks like:

import torch
import torch.nn as nn
from torch.autograd import Variable  # no-op wrapper since PyTorch 0.4; kept because the code below uses it


class GeneralizedDiceLoss(nn.Module):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf
    """

    def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True):
        super(GeneralizedDiceLoss, self).__init__()
        self.epsilon = epsilon
        self.register_buffer('weight', weight)
        self.ignore_index = ignore_index
        if sigmoid_normalization:
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = nn.Softmax(dim=1)

    def forward(self, input, target):
        # get probabilities from logits
        input = input.float()
        input = self.normalization(input)

        assert input.size() == target.size(), "'input' and 'target' must have the same shape"

        # mask ignore_index if present
        if self.ignore_index is not None:
            mask = target.clone().ne_(self.ignore_index)
            mask.requires_grad = False

            input = input * mask
            target = target * mask

        input = flatten(input)
        target = flatten(target)

        target = target.float()
        # per-class weights w_c = 1 / (sum of ground-truth voxels in class c)^2, as in the GDL paper
        target_sum = target.sum(-1)
        class_weights = Variable(1. / (target_sum * target_sum).clamp(min=self.epsilon), requires_grad=False)

        # numerator: per-class intersection of predictions and targets, weighted by class_weights
        intersect = (input * target).sum(-1) * class_weights
        if self.weight is not None:
            weight = Variable(self.weight, requires_grad=False)
            intersect = weight * intersect
        intersect = intersect.sum()

        # denominator: per-class sum of predictions and targets, weighted the same way
        denominator = ((input + target).sum(-1) * class_weights).sum()

        return 1. - 2. * intersect / denominator.clamp(min=self.epsilon)

def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.view(C, -1)
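For reference, the shape convention this loss expects: `forward` asserts that `input` and `target` have the same shape, so for `(N, C, D, H, W)` logits the target must be a one-hot volume of the same shape rather than a `(N, D, H, W)` index map, and the `class_weights` line implements the paper's per-class weighting w_c = 1 / (sum_i g_ci)^2. A minimal sketch of that convention, assuming the class and `flatten` above are in scope and using made-up sizes:

import torch
import torch.nn.functional as F

# hypothetical sizes: batch of 1, 2 classes, a 16^3 volume
N, C, D, H, W = 1, 2, 16, 16, 16

logits = torch.randn(N, C, D, H, W, requires_grad=True)  # raw network output
labels = torch.randint(0, C, (N, D, H, W))               # integer class indices

# one-hot encode to (N, C, D, H, W) so the target matches the logits' shape
target = F.one_hot(labels, num_classes=C).permute(0, 4, 1, 2, 3).float()

criterion = GeneralizedDiceLoss()
loss = criterion(logits, target)  # scalar with a grad_fn
loss.backward()                   # gradients flow back to logits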

This is what my training code looks like:

for index, (image, mask) in enumerate(train_loader):
    image = torch.unsqueeze(image, 0).float().cuda()
    label = mask.cuda()
    # label = label.view(-1,)
    output_1, output_2 = model(image)
    max_output = torch.argmax(output_2, 1)
    # max_output shape = torch.Size([1, 128, 128, 128]), label shape = torch.Size([1, 128, 128, 128])
    loss = loss_function(max_output, label)
    print(loss)
    losses.append(loss.item())
    optimizer = optim.Adam(model.parameters())
    optimizer.zero_grad()
    loss.backward()

But when I run it, I get:

Traceback (most recent call last):
  File "train.py", line 89, in <module>
    loss.backward()
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I'm not sure how to proceed; any suggestions would be really helpful. Thanks in advance.
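Update: a minimal sketch that seems to reproduce the same error outside my model. `torch.argmax` returns integer indices with no `grad_fn`, so I suspect the graph is broken at the `max_output = torch.argmax(output_2, 1)` line (the sizes here are made up):

import torch

logits = torch.randn(1, 2, 4, 4, 4, requires_grad=True)
pred = torch.argmax(logits, 1)  # LongTensor with no grad_fn: argmax is not differentiable
loss = pred.float().mean()      # any loss built on pred is detached from the graph
loss.backward()                 # RuntimeError: element 0 of tensors does not require grad ...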

0 Answers:

There are no answers