PyTorch使用布尔蒙版提取张量元素(保留尺寸)

时间:2020-06-17 18:56:20

标签: boolean pytorch mask

说,我有一个PyTorch 2x2张量,并且还生成了一个相同尺寸(2x2)的布尔张量。我想用它作为面具。

例如,如果我有张量:

tensor([[1, 3],
        [4, 7]])

如果我的面具是:

tensor([[ True, False],
        [False,  True]])

我想使用该掩码获取张量,其中保留与原始张量中的True相对应的元素,而与False相对应的元素在输出张量中设置为零。

预期输出:

tensor([[1, 0],
        [0, 7]])

感谢您的帮助。谢谢!

1 个答案:

答案 0 :(得分:1)

假设您有:

t = torch.Tensor([[1,2], [3,4]])
mask = torch.Tensor([[True,False], [False,True]])

您可以通过以下方式使用面具:

masked_t = t * mask

,输出将是:

tensor([[1., 0.],
        [0., 4.]])