如何返回基于一维掩码截断的 k-dim pytorch 张量

时间:2021-03-23 15:37:49

标签: python pytorch boolean

所以假设我有一个 k-dim 张量和一个 1-dim 掩码,它在 pytorch 中经常用于可变长度序列,我想返回一个张量,该张量表示直到掩码中第一个假值的元素.下面是一个例子:

import torch

a = torch.tensor([[1,2],[3,4],[5,6],[0,0],[0,0],[0,0]])
b = torch.tensor([True,True,True,False,False,False])

# magic goes here, result of c should be:

print(c)
>>> [[1,2],[3,4],[5,6]]

在这个例子中,输入张量是二维的,但它可以是 k-d,在这些维度上有任意数量的值。只有第一个维度需要匹配掩码维度。因此,执行 torch.masked_select 不起作用,因为要截断的张量不像掩码那样是一维的,而且由于您不知道维度,因此挤压和解压也不是解决方案。

掩码对于前 k 个元素始终为真,对于其余元素为假,但如果您的解决方案不“依赖”于此,那很好。

这似乎人们会一直这样做,但我找不到任何地方已经回答了这个问题。

1 个答案:

答案 0 :(得分:0)

您可以简单地将掩码作为切片索引传递给张量:

c = a[b]
>>> c
tensor([[1, 2],
        [3, 4],
        [5, 6]])