pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?

时间:2019-05-11 07:31:36

标签: python machine-learning pytorch

我有一个二维张量,每行中都有一些非零元素,像这样:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

我想要一个张量,在每行中包含第一个非零元素的索引:

indices = tensor([2],
                 [3])

如何在Pytorch中计算它?

2 个答案:

答案 0 :(得分:5)

我简化了Iman的方法来执行以下操作:

idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)

答案 1 :(得分:0)

我可以为我的问题找到一个棘手的答案:

  tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                     [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
  idx = reversed(torch.Tensor(range(1,8)))
  print(idx)

  tmp2= torch.einsum("ab,b->ab", (tmp, idx))

  print(tmp2)

  indices = torch.argmax(tmp2, 1, keepdim=True)
  print(indeces)

结果是:

tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])