沿两个维度的 Torch 布尔索引

时间:2021-06-17 08:22:23

标签: python pytorch

假设您有一个二维平方张量:

x = torch.tensor([[ 0,  1,  2,  3,  4,  5],
                  [ 6,  7,  8,  9, 10, 11],
                  [12, 13, 14, 15, 16, 17],
                  [18, 19, 20, 21, 22, 23],
                  [24, 25, 26, 27, 28, 29],
                  [30, 31, 32, 33, 34, 35]])

考虑到您有一个张量 keep,例如:

keep = torch.tensor([True, False, True, True, False, False])

所需的输出是:

tensor([[ 0,  2,  3],
        [12, 14, 15],
        [18, 20, 21]])

一些不起作用的东西

我希望 x[keep, keep] 可以工作,但它只选择对角线上的元素。

让它长期有效 - 面具

一种方法是使用面具,但它很乏味:

mask = keep.view(-1, 1) * keep
submatrix_size = keep.sum()
x[mask].view(sub_matrix_size, -1)

让它在短时间内发挥作用

另一种方法是:

x[keep][:, keep]

我的问题是:short way 是在具有相同布尔张量的两个维度上进行选择的最佳方法吗?在 PyTorch 中还有其他方法吗?

0 个答案:

没有答案