这个问题听起来很基础。但是,当我尝试在numpy数组上使用where
或boolean
条件时,它总是返回一个扁平化的数组。
我有NumPy数组
P = array([[ 0.49530662, 0.07901 , -0.19012371],
[ 0.1421513 , 0.48607405, -0.20315014],
[ 0.76467375, 0.16479826, -0.56598029],
[ 0.53530718, -0.21166188, -0.08773241]])
我只想提取负值数组,但是当我尝试
P[P<0]
array([-0.19012371, -0.41421612, -0.20315014, -0.56598029, -0.21166188,
-0.08773241, -0.09241335])
P[np.where(P<0)]
array([-0.19012371, -0.41421612, -0.20315014, -0.56598029, -0.21166188,
-0.08773241, -0.09241335])
我得到一个展平的数组。如何提取表格数组
array([[ 0, 0, -0.19012371],
[ 0 , 0, -0.20315014],
[ 0, 0, -0.56598029],
[ 0, -0.21166188, -0.08773241]])
我不希望创建一个临时数组,然后使用类似Temp[Temp>=0] = 0
答案 0 :(得分:3)
因为您的需要是
我想“提取”只有负值的数组
您可以将numpy.where()
与您的条件一起使用(检查负值),这样可以保留数组的维,如下例所示:
In [61]: np.where(P<0, P, 0)
Out[61]:
array([[ 0. , 0. , -0.19012371],
[ 0. , 0. , -0.20315014],
[ 0. , 0. , -0.56598029],
[ 0. , -0.21166188, -0.08773241]])
其中P
是您的输入数组。
另一个想法可能是使用numpy.zeros_like()
初始化相同的形状数组,并使用numpy.where()
收集满足我们条件的索引。
# initialize our result array with zeros
In [106]: non_positives = np.zeros_like(P)
# gather the indices where our condition is obeyed
In [107]: idxs = np.where(P < 0)
# copy the negative values to correct indices
In [108]: non_positives[idxs] = P[idxs]
In [109]: non_positives
Out[109]:
array([[ 0. , 0. , -0.19012371],
[ 0. , 0. , -0.20315014],
[ 0. , 0. , -0.56598029],
[ 0. , -0.21166188, -0.08773241]])
另一种想法是简单地使用准系统numpy.clip()
API,如果我们省略out=
kwarg,它将返回一个新数组。
In [22]: np.clip(P, -np.inf, 0) # P.clip(-np.inf, 0)
Out[22]:
array([[ 0. , 0. , -0.19012371],
[ 0. , 0. , -0.20315014],
[ 0. , 0. , -0.56598029],
[ 0. , -0.21166188, -0.08773241]])
答案 1 :(得分:0)
这应该起作用,本质上是获取大于0的所有元素的索引,并将它们设置为0,这样可以保留尺寸!我从这里得到这个主意:Replace all elements of Python NumPy Array that are greater than some value
还请注意,我已经修改了原始数组,这里没有使用临时数组
import numpy as np
P = np.array([[ 0.49530662, 0.07901 , -0.19012371],
[ 0.1421513 , 0.48607405, -0.20315014],
[ 0.76467375, 0.16479826, -0.56598029],
[ 0.53530718, -0.21166188, -0.08773241]])
P[P >= 0] = 0
print(P)
输出将为
[[ 0. 0. -0.19012371]
[ 0. 0. -0.20315014]
[ 0. 0. -0.56598029]
[ 0. -0.21166188 -0.08773241]]
如下所述,这将修改数组,因此我们应使用np.where(P<0, P 0)
来保存原始数组,如下所示,感谢@ kmario123如下
import numpy as np
P = np.array([[ 0.49530662, 0.07901 , -0.19012371],
[ 0.1421513 , 0.48607405, -0.20315014],
[ 0.76467375, 0.16479826, -0.56598029],
[ 0.53530718, -0.21166188, -0.08773241]])
print( np.where(P<0, P, 0))
print(P)
输出将为
[[ 0. 0. -0.19012371]
[ 0. 0. -0.20315014]
[ 0. 0. -0.56598029]
[ 0. -0.21166188 -0.08773241]]
[[ 0.49530662 0.07901 -0.19012371]
[ 0.1421513 0.48607405 -0.20315014]
[ 0.76467375 0.16479826 -0.56598029]
[ 0.53530718 -0.21166188 -0.08773241]]