pytorch中的张量是否有类似df.mask的类似功能?

时间:2019-11-25 18:05:08

标签: python tensorflow pytorch tensor

我想将2D张量中的所有0替换为-5。

有了数据框,我可以轻松做到这一点:

df = df.mask(df=0, -5) 

但这不适用于张量。我尝试过:

 y = torch.where(y = 0, -5, y) 

2 个答案:

答案 0 :(得分:2)

很简单,只需使用此

y[y==0]=-5

答案 1 :(得分:2)

一般有两种方法。

上面由prhmma给出的一种方法是使用就地突变,例如y[y == 0] = -5。它既好又高效,但是会中断自动分级操作。因此,如果您希望渐变流过y,则不应这样做。

另一种方法是尝试使用torch.where。正确的咒语是

y = torch.where(y == 0, torch.tensor(-5), y)

或者,如果要与设备和dtype无关,请

five = torch.tensor(-5, dtype=y.dtype, device=y.device)
y = torch.where(y == 0, five, y)

where不接受标量的事实是令人讨厌的剪纸,但这就是ATM的样子。请注意,尽管选择本身是离散的并且显然不可区分,但此操作将使梯度流过两个操作数。