如何过滤特定维度上的元素

时间:2020-09-05 22:45:20

标签: numpy torch

假设张量a的形状为(128, 20, 10)。我想根据以下条件将此张量过滤为形状为(128,19,10)的张量b:在每个(20, 10)矩阵中,我想删除一行,其中列的总和为零。如何通过切片做到这一点?

我应该能够执行以下操作:

mask = a.abs().sum(dim=2) != 0
a = a[mask]

但这给了我错误的形状。

0 个答案:

没有答案