我有一个PyTorch
张量a
,形状如下:
import torch
a = torch.tensor([[[1., 0., 0., 0.]],
[[0., 1., 0., 0.]],
[[1., 0., 0., 0.]],
[[0., 0., 0., 1.]],
[[1., 0., 0., 0.]],
[[0., 0., 0., 1.]],
[[1., 0., 0., 0.]]])
张量a
的每一行都有4个元素,分别为1和0。说我索引此张量的行和列。因此,例如,第0行(最上一行)中的条目为[[1., 0., 0., 0.]]
,而第3列(最右列)中的条目为[[0., 0., 0., 1., 0., 1., 0.]]
。
从给定的张量中,我想确定1.
最常出现的列的索引。例如,对于张量a
,此类列的索引将为0。如果数量为1,则存在联系,我仍然希望获得所有这些约束列索引。>
如何在Python上轻松完成此任务?
谢谢
答案 0 :(得分:1)
如果矩阵仅包含0和1,则可以对每一列的元素求和,然后搜索最大的和:
import numpy as np
% sum over columns
sumsi = torch.sum(a, dim=1)
% find where maximum
col_idx = np.where(sumsi==np.max(sumsi))