我正在寻找一种基于列表来过滤numpy数组的方法
input_array = [[0,4,6],[2,1,1],[6,6,9]]
list=[9,4]
...
output_array = [[0,1,0],[0,0,0],[0,0,1]]
我目前正在展平数组,并将其转换为列表并返回。看起来非常不熟悉:
list=[9,4]
shape = input_array.shape
input_array = input_array.flatten()
output_array = np.array([int(i in list) for i in input_array])
output_array = output_array.reshape(shape)
答案 0 :(得分:2)
我们可以使用np.in1d
来获取匹配的掩码。现在,np.in1d
在处理之前将输入展平为1D
。因此,它的输出将重新转换回2D
,然后转换为int
以获得0s
和1s
的输出。
因此,实施将是 -
np.in1d(input_array, list).reshape(input_array.shape).astype(int)
示例运行 -
In [40]: input_array
Out[40]:
array([[0, 4, 6],
[2, 1, 1],
[6, 6, 9]])
In [41]: list=[9,4]
In [42]: np.in1d(input_array, list).reshape(input_array.shape).astype(int)
Out[42]:
array([[0, 1, 0],
[0, 0, 0],
[0, 0, 1]])