如何在 PyTorch 中为损失函数添加掩码

时间:2021-06-21 06:46:26

标签: python pytorch

我的 PyTorch model 输出两个图像:oppsuedo-op,我希望仅在 loss(op_i,gt_i)<loss(psuedo-op_i,gt_i)i 用于索引像素。显然,loss(op,gt).backward() 无法做到这一点。那怎么办?

有一些粗略的解决方案如下:

def loss(pred, target):
       """ shape of pred/target is BatchxChannelxHeightxWidth, where Channel=1 """
        
        abs_diff = torch.abs(target - pred)
        l1_loss = abs_diff.mean(1, True)

        return l1_loss

o1 = loss(op,gt)
o2 = loss(psuedo-op,gt)
o = torch.cat((o1,o2),dim=1)

value, idx = torch.min(o,dim=1)

NOW SOMEHOW USE IDX TO GENERATE MASK AND SELECTIVE BACKPROPAGATION

如果它允许我在 o1 上进行反向传播,但仅适用于 o1<o2 的那些像素,任何其他解决方案也可以使用。

1 个答案:

答案 0 :(得分:1)

我认为您可以为此目的使用 relu 函数。由于您只需要在 o1 上进行反向传播,您首先需要 detach 损失 o2。并且还有一个减号来校正梯度的符号。

# This diff is 0 when o1 > o2, equal to o2-o1 otherwise
o_diff = nn.functional.relu(o2.detach()-o1)
# gradient of (-relu(b-x)) is 0 if b-x < 0, 1 otherwise
(-o_diff).sum().backward()

在这里,使用 relu 作为对 o2-o1 符号的一种条件可以很容易地消除带负号系数的梯度

我需要强调的是,由于 o2 与图形分离,因此它相对于您的网络是一个常数,因此它不会影响梯度,因此此操作实现了您所需要的:它基本上是反向传播 { {1}} 如果 d/dx(-relu(b-o1(x)) 为 0,否则为 b < o1(x)(其中 d/dx(o1(x)) 为常数)。