我想基于逻辑比较将numpy数组拆分为三个不同的数组。我要拆分的numpy数组称为x
。它的形状如下所示,但它的条目各不相同:(响应Saullo Castro的评论,我包含了一个略有不同的数组x。)
array([[ 0.46006547, 0.5580928 , 0.70164242, 0.84519205, 1.4 ],
[ 0.00912908, 0.00912908, 0.05 , 0.05 , 0.05 ]])
此数组的值沿列单调递增。我还有另外两个名为lowest_gridpoints
和highest_gridpoints
的数组。这些数组的条目也各不相同,但形状始终与以下内容相同:
array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
我想申请的选择程序如下:
lowest_gridpoints
中任何值的值的所有列都应从x
中删除,并构成数组temp1
。 highest_gridpoints
中任何值的值的列都应从x
中删除,并构成数组temp2
。 x
或temp1
中包含的所有temp2
列都构成了数组x_new
。 我编写的以下代码实现了该任务。
if np.any( x[:,-1] > highest_gridpoints ) or np.any( x[:,0] < lowest_gridpoints ):
for idx, sample, in enumerate(x.T):
if np.any( sample > highest_gridpoints):
max_idx = idx
break
elif np.any( sample < lowest_gridpoints ):
min_idx = idx
temp1, temp2 = np.array([[],[]]), np.array([[],[]])
if 'min_idx' in locals():
temp1 = x[:,0:min_idx+1]
if 'max_idx' in locals():
temp2 = x[:,max_idx:]
if 'min_idx' in locals() or 'max_idx' in locals():
if 'min_idx' not in locals():
min_idx = -1
if 'max_idx' not in locals():
max_idx = x.shape[1]
x_new = x[:,min_idx+1:max_idx]
但是,我怀疑由于循环的大量使用,此代码效率非常低。另外,我认为语法很臃肿。
有人对代码有了一个想法,可以更有效地实现上述任务或看起来简洁吗?
答案 0 :(得分:1)
仅限问题的第一部分
from numpy import *
x = array([[ 0.46006547, 0.5580928 , 0.70164242, 0.84519205, 1.4 ],
[ 0.00912908, 0.00912908, 0.05 , 0.05 , 0.05 ]])
low, high = array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
# construct an array of two rows of bools expressing your conditions
indices1 = array((x[0,:]<low[0], x[1,:]<low[1]))
print indices1
# do an or of the values along the first axis
indices1 = any(indices1, axis=0)
# now it's a single row array
print indices1
# use the indices1 to extract what you want,
# the double transposition because the elements
# of a 2d array are the rows
tmp1 = x.T[indices1].T
print tmp1
# [[ True True False False False]
# [ True True False False False]]
# [ True True False False False]
# [[ 0.46006547 0.5580928 ]
# [ 0.00912908 0.00912908]]
接下来构造类似indices2
和tmp2
,余数的指数是对前两个指数or
的否定。 (即numpy.logical_not(numpy.logical_or(i1,i2))
)。
<强>附录强>
另一种方法,如果你有数千个条目可能会更快,这意味着numpy.searchsorted
from numpy import *
x = array([[ 0.46006547, 0.5580928 , 0.70164242, 0.84519205, 1.4 ],
[ 0.00912908, 0.00912908, 0.05 , 0.05 , 0.05 ]])
low, high = array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')
h0l = searchsorted(x[0,:], high[0], side='left')
h1l = searchsorted(x[1,:], high[1], side='left')
lr = max(l0r, l1r)
hl = min(h0l, h1l)
print lr, hl
print x[:,:lr]
print x[:,lr:hl]
print x[:,hl]
# 2 4
# [[ 0.46006547 0.5580928 ]
# [ 0.00912908 0.00912908]]
# [[ 0.70164242 0.84519205]
# [ 0.05 0.05 ]]
# [ 1.4 0.05]
排除重叠可以通过hl = max(lr, hl)
获得。 NB在previuos方法中,数组切片被复制到新对象,在这里您可以获得x
的视图,如果您想要新对象,则必须明确。
修改 不必要的优化
如果我们在x
es的第二对中仅使用sortedsearch
的上半部分(如果你查看代码,你会看到我的意思......)我们得到两个好处,1)搜索速度非常小(sortedsearch
总是足够快)和2)自动管理重叠的情况。
作为奖励,用于复制新阵列中x
的片段的代码。 NB x
已更改为强制重叠
from numpy import *
# I changed x to force overlap
x = array([[ 0.46006547, 1.4 , 1.4, 1.4, 1.4 ],
[ 0.00912908, 0.00912908, 0.05, 0.05, 0.05 ]])
low, high = array([ 0.633, 0.01 ]), array([ 1.325, 0.99 ])
l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')
lr = max(l0r, l1r)
h0l = searchsorted(x[0,lr:], high[0], side='left')
h1l = searchsorted(x[1,lr:], high[1], side='left')
hl = min(h0l, h1l) + lr
t1 = x[:,range(lr)]
xn = x[:,range(lr,hl)]
ncol = shape(x)[1]
t2 = x[:,range(hl,ncol)]
print x
del(x)
print
print t1
print
# note that xn is a void array
print xn
print
print t2
# [[ 0.46006547 1.4 1.4 1.4 1.4 ]
# [ 0.00912908 0.00912908 0.05 0.05 0.05 ]]
#
# [[ 0.46006547 1.4 ]
# [ 0.00912908 0.00912908]]
#
# []
#
# [[ 1.4 1.4 1.4 ]
# [ 0.05 0.05 0.05]]