查找具有最多1的PyTorch张量的列索引

时间:2020-08-09 18:36:49

标签: python pytorch tensor

我有一个PyTorch张量a,形状如下:

import torch
a = torch.tensor([[[1., 0., 0., 0.]],
        [[0., 1., 0., 0.]],
        [[1., 0., 0., 0.]],
        [[0., 0., 0., 1.]],
        [[1., 0., 0., 0.]],
        [[0., 0., 0., 1.]],
        [[1., 0., 0., 0.]]])

张量a的每一行都有4个元素,分别为1和0。说我索引此张量的行和列。因此,例如,第0行(最上一行)中的条目为[[1., 0., 0., 0.]],而第3列(最右列)中的条目为[[0., 0., 0., 1., 0., 1., 0.]]

从给定的张量中,我想确定1.最常出现的列的索引。例如,对于张量a,此类列的索引将为0。如果数量为1,则存在联系,我仍然希望获得所有这些约束列索引。

如何在Python上轻松完成此任务?

谢谢

1 个答案:

答案 0 :(得分:1)

如果矩阵仅包含0和1,则可以对每一列的元素求和,然后搜索最大的和:

import numpy as np

% sum over columns
sumsi = torch.sum(a, dim=1)

% find where maximum
col_idx = np.where(sumsi==np.max(sumsi))