PyTorch:查找满足给定条件的张量行的索引

时间:2019-10-19 13:55:12

标签: pytorch tensor indices

我有一个整数的2D张量,我想找到其列包含任何指定值的行的索引。

例如,给定data张量

data = torch.randint(10, (10,5))
tensor([[4, 7, 9, 8, 5],
        [7, 4, 4, 3, 3],
        [4, 9, 7, 7, 0],
        [8, 1, 4, 6, 0],
        [5, 9, 9, 5, 8],
        [9, 3, 7, 6, 5],
        [0, 2, 3, 5, 2],
        [4, 4, 1, 5, 1],
        [9, 8, 3, 7, 1],
        [3, 2, 0, 4, 7]])

以及这些值的列表(或张量)

col1_values = [4, 5]
col2_values = [9, 4]

我想这样获取索引:

tensor([2, 4, 7])

我知道我可以将布尔型掩码一对一地组合

filter = ((data[:,0] == 4) + (data[:,0] == 5)) * ((data[:,1] == 9) + (data[:,1] == 4))
indices = filter.nonzero().squeeze()

我可以使用循环自动执行。但是有没有办法使用pytorch函数更有效地做到这一点?

0 个答案:

没有答案