我想将2D张量中的所有0替换为-5。
有了数据框,我可以轻松做到这一点:
df = df.mask(df=0, -5)
但这不适用于张量。我尝试过:
y = torch.where(y = 0, -5, y)
答案 0 :(得分:2)
很简单,只需使用此
y[y==0]=-5
答案 1 :(得分:2)
一般有两种方法。
上面由prhmma
给出的一种方法是使用就地突变,例如y[y == 0] = -5
。它既好又高效,但是会中断自动分级操作。因此,如果您希望渐变流过y,则不应这样做。
另一种方法是尝试使用torch.where
。正确的咒语是
y = torch.where(y == 0, torch.tensor(-5), y)
或者,如果要与设备和dtype无关,请
five = torch.tensor(-5, dtype=y.dtype, device=y.device)
y = torch.where(y == 0, five, y)
where
不接受标量的事实是令人讨厌的剪纸,但这就是ATM的样子。请注意,尽管选择本身是离散的并且显然不可区分,但此操作将使梯度流过两个操作数。