pytorch masked_fill:为什么我不能屏蔽所有零?

时间:2019-05-27 15:11:30

标签: python tensorflow pytorch

我想用-np.inf掩盖分数矩阵中的所有零,但我只能掩盖部分零,就像

enter image description here

您会在右上角看到仍然没有被-np.inf屏蔽的零。

这是我的代码:

q = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
k = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
scores = torch.matmul(q, k.transpose(0,1)) / math.sqrt(10)
mask = torch.Tensor([1,1,1,1,0,0])
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask==0, -np.inf)

也许面具错了吗?

4 个答案:

答案 0 :(得分:2)

您的面具是错误的。试试

scores = scores.masked_fill(scores == 0, -np.inf)

scores现在看起来像

tensor([[1.4796, 1.2361, 1.2137, 0.9487,   -inf,   -inf],
        [0.6889, 0.4428, 0.6302, 0.4388,   -inf,   -inf],
        [0.8842, 0.7614, 0.8311, 0.6431,   -inf,   -inf],
        [0.9884, 0.8430, 0.7982, 0.7323,   -inf,   -inf],
        [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf]])

答案 1 :(得分:2)

在mujjiga的代码中,得分张量本身被用作掩码,因此尽管不是掩码的通常用途,它也会将所有0替换为-inf。遮罩通常独立于要遮罩的张量。

答案 2 :(得分:1)

是的,您的代码正确,输出显示正确的行为。 当前,您的蒙版具有[6,1]形状,因此将首先遮盖每列中的最后两个元素。

>>> mask = torch.Tensor([1,1,1,1,0,0])

>>> mask.shape

torch.Size([6])

>>> mask = mask.unsqueeze(1)

>>> mask.shape

torch.Size([6, 1])

答案 3 :(得分:0)

或者甚至通过稍微更改代码也可以使用

import math

q = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
k = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
scores = torch.matmul(q, k.transpose(0,1)) / math.sqrt(10)
mask = torch.Tensor([1,1,1,1,0,0])
mask2 = mask.unsqueeze(1)
scores = scores.masked_fill(mask2==0, -np.inf)
mask = mask.unsqueeze(0)
scores = scores.masked_fill(mask==0, -np.inf)
scores