在numpy数组的每一行中获取第一个零的索引数组

时间:2017-05-31 13:55:34

标签: python numpy

我有一个包含0和255的1650行和1275列的numpy数组。 我想得到行中每个第一个零的索引并将其存储在一个数组中。 我用循环来实现这一点。这是示例代码

#new_arr is a numpy array and k is an empty array 
for i in range(new_arr.shape[0]):
  if not np.all(new_arr[i,:]) == 255:
   x = np.where(new_arr[i,:]==0)[0][0]
   k.append(x)
  else:
   k.append(-1)

1650行大约需要1.3秒。有没有其他方法或函数以更快的方式获取索引数组?

1 个答案:

答案 0 :(得分:5)

一种方法是使用==0获取匹配的掩码,然后沿每行获取argmax,即为argmax(axis=1)提供每行的第一个匹配索引 -

(arr==0).argmax(axis=1)

示例运行 -

In [443]: arr
Out[443]: 
array([[0, 1, 0, 2, 2, 1, 2, 2],
       [1, 1, 2, 2, 2, 1, 0, 1],
       [2, 1, 0, 1, 0, 0, 2, 0],
       [2, 2, 1, 0, 1, 2, 1, 0]])

In [444]: (arr==0).argmax(axis=1)
Out[444]: array([0, 6, 2, 3])

捕获非零行(如果可以的话,

为了方便那些没有任何零的行,我们需要再做一步工作,并进行一些掩蔽 -

In [445]: arr[2] = 9

In [446]: arr
Out[446]: 
array([[0, 1, 0, 2, 2, 1, 2, 2],
       [1, 1, 2, 2, 2, 1, 0, 1],
       [9, 9, 9, 9, 9, 9, 9, 9],
       [2, 2, 1, 0, 1, 2, 1, 0]])

In [447]: mask = arr==0

In [448]: np.where(mask.any(1), mask.argmax(1), -1)
Out[448]: array([ 0,  6, -1,  3])