我有一个二进制的numpy掩码数组,当至少3个连续出现1时,我想查找沿轴= 0的元素的索引。如果没有出现,则-999或NaN或任何表示它的东西是不是索引。 例如,我的数组就像:
masked_array(
data=[[[1.0, 0.0],
[0.0, 1.0]],
[[0.0, 1.0],
[0.0, 1.0]],
[[1.0, 1.0],
[1.0, 1.0]],
[[1.0, 1.0],
[1.0, 0.0]],
[[1.0, --],
[0.0, 1.0]],
[[1.0, 1.0],
[1.0, 1.0]]])
我想得到这样的东西:
array([[ 2, 1],
[-999, 0]])
最Python的方式是什么?任何提示将不胜感激。
答案 0 :(得分:2)
IIUC,您可以首先将np数组制作为2D并构建一个数据帧,这使一切变得更加容易。看看
row, cols = m.shape[0], m.shape[1] * m.shape[2]
df = pd.DataFrame(m.reshape(row, cols))
0 1 2 3
0 1.0 0.0 0.0 1.0
1 0.0 1.0 0.0 1.0
2 1.0 1.0 1.0 1.0
3 1.0 1.0 1.0 0.0
4 1.0 0.0 0.0 1.0
5 1.0 1.0 1.0 1.0
现在,您可以在rolling
上使用3
的反向axis=0
窗口,并检查all
元素是否为1
ndf = df[::-1].rolling(3, axis=0).apply(all, raw=True)[::-1]
0 1 2 3
0 NaN NaN NaN 1.0
1 NaN 1.0 NaN NaN
2 1.0 NaN NaN NaN
3 1.0 NaN NaN NaN
4 NaN NaN NaN NaN
5 NaN NaN NaN NaN
并使用idxmax()
获取第一次出现的1
的索引
ndf[ndf>=1].idxmax()
0 2.0
1 1.0
2 NaN
3 0.0
dtype: float
要形象化您的描述方式,只需调整输出形状
ndf[ndf>=1].idxmax().values.reshape(m.shape[1], m.shape[2])
array([[ 2., 1.],
[nan, 0.]])