我有一个整数的2D张量,我想找到其列包含任何指定值的行的索引。
例如,给定data
张量
data = torch.randint(10, (10,5))
tensor([[4, 7, 9, 8, 5],
[7, 4, 4, 3, 3],
[4, 9, 7, 7, 0],
[8, 1, 4, 6, 0],
[5, 9, 9, 5, 8],
[9, 3, 7, 6, 5],
[0, 2, 3, 5, 2],
[4, 4, 1, 5, 1],
[9, 8, 3, 7, 1],
[3, 2, 0, 4, 7]])
以及这些值的列表(或张量)
col1_values = [4, 5]
col2_values = [9, 4]
我想这样获取索引:
tensor([2, 4, 7])
我知道我可以将布尔型掩码一对一地组合
filter = ((data[:,0] == 4) + (data[:,0] == 5)) * ((data[:,1] == 9) + (data[:,1] == 4))
indices = filter.nonzero().squeeze()
我可以使用循环自动执行。但是有没有办法使用pytorch函数更有效地做到这一点?