所以假设我有一个 k-dim 张量和一个 1-dim 掩码,它在 pytorch 中经常用于可变长度序列,我想返回一个张量,该张量表示直到掩码中第一个假值的元素.下面是一个例子:
import torch
a = torch.tensor([[1,2],[3,4],[5,6],[0,0],[0,0],[0,0]])
b = torch.tensor([True,True,True,False,False,False])
# magic goes here, result of c should be:
print(c)
>>> [[1,2],[3,4],[5,6]]
在这个例子中,输入张量是二维的,但它可以是 k-d,在这些维度上有任意数量的值。只有第一个维度需要匹配掩码维度。因此,执行 torch.masked_select 不起作用,因为要截断的张量不像掩码那样是一维的,而且由于您不知道维度,因此挤压和解压也不是解决方案。
掩码对于前 k 个元素始终为真,对于其余元素为假,但如果您的解决方案不“依赖”于此,那很好。
这似乎人们会一直这样做,但我找不到任何地方已经回答了这个问题。
答案 0 :(得分:0)
您可以简单地将掩码作为切片索引传递给张量:
c = a[b]
>>> c
tensor([[1, 2],
[3, 4],
[5, 6]])