假设您有一个二维平方张量:
x = torch.tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])
考虑到您有一个张量 keep
,例如:
keep = torch.tensor([True, False, True, True, False, False])
所需的输出是:
tensor([[ 0, 2, 3],
[12, 14, 15],
[18, 20, 21]])
我希望 x[keep, keep]
可以工作,但它只选择对角线上的元素。
一种方法是使用面具,但它很乏味:
mask = keep.view(-1, 1) * keep
submatrix_size = keep.sum()
x[mask].view(sub_matrix_size, -1)
另一种方法是:
x[keep][:, keep]
我的问题是:short way 是在具有相同布尔张量的两个维度上进行选择的最佳方法吗?在 PyTorch 中还有其他方法吗?