查找最后一个非零元素3D数组 - numpy数组

时间:2017-07-31 18:47:48

标签: python numpy

我需要找到一种方法来执行此操作,我有一个形状数组

[batch_size,150,1]

表示batch_size整数序列,每个序列长150个元素,但是每个序列都添加了很多零,以便生成相同长度的所有序列。我需要找到每个序列的最后一个非零元素,并将它放在一个数组中,这个数组的形状必须是[batch_size]。我会尝试以下方法:

last = []
for j in range(0 , inputs.shape[0] ):
  tnew = np.array( inputs[j][:][0] )
  tnew = np.trim_zeros(tnew )
   last.append( int(tnew[-1]) )

但我不知道是否有更好的方法来做到这一点,不必像这样循环遍历每个元素。

感谢您的回答和帮助。

测试数据

a = np.array([[[1],[0],[0],[0],[0],[0]],
              [[1],[2],[0],[0],[0],[0]],
              [[1],[2],[3],[0],[0],[0]],
              [[1],[2],[3],[4],[0],[0]],
              [[1],[2],[3],[4],[5],[0]]])

1 个答案:

答案 0 :(得分:2)

这是一种矢量化方法 -

a.shape[1] - (a!=0)[:,::-1].argmax(1) - 1

示例运行 -

In [191]: a = np.random.randint(0,3,(3,6,1))

In [192]: a
Out[192]: 
array([[[2],
        [1],
        [2],
        [2],
        [2],
        [0]],

       [[2],
        [1],
        [1],
        [0],
        [2],
        [0]],

       [[2],
        [1],
        [2],
        [0],
        [1],
        [1]]])

In [193]: a.shape[1] - (a!=0)[:,::-1].argmax(1) - 1
Out[193]: 
array([[4],
       [4],
       [5]])