我的 PyTorch model
输出两个图像:op
和 psuedo-op
,我希望仅在 loss(op_i,gt_i)<loss(psuedo-op_i,gt_i)
、i
用于索引像素。显然,loss(op,gt).backward()
无法做到这一点。那怎么办?
有一些粗略的解决方案如下:
def loss(pred, target):
""" shape of pred/target is BatchxChannelxHeightxWidth, where Channel=1 """
abs_diff = torch.abs(target - pred)
l1_loss = abs_diff.mean(1, True)
return l1_loss
o1 = loss(op,gt)
o2 = loss(psuedo-op,gt)
o = torch.cat((o1,o2),dim=1)
value, idx = torch.min(o,dim=1)
NOW SOMEHOW USE IDX TO GENERATE MASK AND SELECTIVE BACKPROPAGATION
如果它允许我在 o1
上进行反向传播,但仅适用于 o1<o2
的那些像素,任何其他解决方案也可以使用。
答案 0 :(得分:1)
我认为您可以为此目的使用 relu
函数。由于您只需要在 o1
上进行反向传播,您首先需要 detach
损失 o2
。并且还有一个减号来校正梯度的符号。
# This diff is 0 when o1 > o2, equal to o2-o1 otherwise
o_diff = nn.functional.relu(o2.detach()-o1)
# gradient of (-relu(b-x)) is 0 if b-x < 0, 1 otherwise
(-o_diff).sum().backward()
在这里,使用 relu
作为对 o2-o1
符号的一种条件可以很容易地消除带负号系数的梯度
我需要强调的是,由于 o2
与图形分离,因此它相对于您的网络是一个常数,因此它不会影响梯度,因此此操作实现了您所需要的:它基本上是反向传播 { {1}} 如果 d/dx(-relu(b-o1(x))
为 0,否则为 b < o1(x)
(其中 d/dx(o1(x))
为常数)。