我有一个包含0和255的1650行和1275列的numpy数组。 我想得到行中每个第一个零的索引并将其存储在一个数组中。 我用循环来实现这一点。这是示例代码
#new_arr is a numpy array and k is an empty array
for i in range(new_arr.shape[0]):
if not np.all(new_arr[i,:]) == 255:
x = np.where(new_arr[i,:]==0)[0][0]
k.append(x)
else:
k.append(-1)
1650行大约需要1.3秒。有没有其他方法或函数以更快的方式获取索引数组?
答案 0 :(得分:5)
一种方法是使用==0
获取匹配的掩码,然后沿每行获取argmax
,即为argmax(axis=1)
提供每行的第一个匹配索引 -
(arr==0).argmax(axis=1)
示例运行 -
In [443]: arr
Out[443]:
array([[0, 1, 0, 2, 2, 1, 2, 2],
[1, 1, 2, 2, 2, 1, 0, 1],
[2, 1, 0, 1, 0, 0, 2, 0],
[2, 2, 1, 0, 1, 2, 1, 0]])
In [444]: (arr==0).argmax(axis=1)
Out[444]: array([0, 6, 2, 3])
捕获非零行(如果可以的话,!)
为了方便那些没有任何零的行,我们需要再做一步工作,并进行一些掩蔽 -
In [445]: arr[2] = 9
In [446]: arr
Out[446]:
array([[0, 1, 0, 2, 2, 1, 2, 2],
[1, 1, 2, 2, 2, 1, 0, 1],
[9, 9, 9, 9, 9, 9, 9, 9],
[2, 2, 1, 0, 1, 2, 1, 0]])
In [447]: mask = arr==0
In [448]: np.where(mask.any(1), mask.argmax(1), -1)
Out[448]: array([ 0, 6, -1, 3])