有效地选择numpy数组的子部分

时间:2014-10-20 10:16:42

标签: python arrays loops numpy

我想基于逻辑比较将numpy数组拆分为三个不同的数组。我要拆分的numpy数组称为x。它的形状如下所示,但它的条目各不相同:(响应Saullo Castro的评论,我包含了一个略有不同的数组x。)

array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
      [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

此数组的值沿列单调递增。我还有另外两个名为lowest_gridpointshighest_gridpoints的数组。这些数组的条目也各不相同,但形状始终与以下内容相同:

 array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

我想申请的选择程序如下:

  • 包含低于lowest_gridpoints中任何值的值的所有列都应从x中删除,并构成数组temp1
  • 所有包含高于highest_gridpoints中任何值的值的列都应从x中删除,并构成数组temp2
  • xtemp1中包含的所有temp2列都构成了数组x_new

我编写的以下代码实现了该任务。

if np.any( x[:,-1] > highest_gridpoints ) or np.any( x[:,0] < lowest_gridpoints ):
    for idx, sample, in enumerate(x.T):
        if np.any( sample > highest_gridpoints):
            max_idx = idx
            break
        elif np.any( sample < lowest_gridpoints ):
            min_idx = idx 
    temp1, temp2 = np.array([[],[]]), np.array([[],[]])
    if 'min_idx' in locals():
        temp1 = x[:,0:min_idx+1]
    if 'max_idx' in locals():
        temp2 = x[:,max_idx:]
    if 'min_idx' in locals() or 'max_idx' in locals():
        if 'min_idx' not in locals():
            min_idx = -1
        if 'max_idx' not in locals():
            max_idx = x.shape[1]
        x_new = x[:,min_idx+1:max_idx]

但是,我怀疑由于循环的大量使用,此代码效率非常低。另外,我认为语法很臃肿。

有人对代码有了一个想法,可以更有效地实现上述任务或看起来简洁吗?

1 个答案:

答案 0 :(得分:1)

仅限问题的第一部分

from numpy import *

x = array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

# construct an array of two rows of bools expressing your conditions
indices1 = array((x[0,:]<low[0], x[1,:]<low[1]))
print indices1

# do an or of the values along the first axis
indices1 = any(indices1, axis=0)
# now it's a single row array
print indices1

# use the indices1 to extract what you want,
# the double transposition because the elements
# of a 2d array are  the rows
tmp1 = x.T[indices1].T
print tmp1

# [[ True  True False False False]
#  [ True  True False False False]]
# [ True  True False False False]
# [[ 0.46006547  0.5580928 ]
#  [ 0.00912908  0.00912908]]

接下来构造类似indices2tmp2,余数的指数是对前两个指数or的否定。 (即numpy.logical_not(numpy.logical_or(i1,i2)))。

<强>附录

另一种方法,如果你有数千个条目可能会更快,这意味着numpy.searchsorted

from numpy import *

x = array([[ 0.46006547,  0.5580928 ,  0.70164242,  0.84519205,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05      ,  0.05      ,  0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')

h0l = searchsorted(x[0,:], high[0], side='left')
h1l = searchsorted(x[1,:], high[1], side='left')

lr = max(l0r, l1r)
hl = min(h0l, h1l)

print lr, hl
print x[:,:lr]
print x[:,lr:hl]
print x[:,hl]

# 2 4
# [[ 0.46006547  0.5580928 ]
#  [ 0.00912908  0.00912908]]
# [[ 0.70164242  0.84519205]
#  [ 0.05        0.05      ]]
# [ 1.4   0.05]

排除重叠可以通过hl = max(lr, hl)获得。 NB在previuos方法中,数组切片被复制到新对象,在这里您可以获得x的视图,如果您想要新对象,则必须明确。

修改 不必要的优化

如果我们在x es的第二对中仅使用sortedsearch的上半部分(如果你查看代码,你会看到我的意思......)我们得到两个好处,1)搜索速度非常小(sortedsearch总是足够快)和2)自动管理重叠的情况。

作为奖励,用于复制新阵列中x的片段的代码。 NB x已更改为强制重叠

from numpy import *

# I changed x to force overlap
x = array([[ 0.46006547,  1.4 ,        1.4,   1.4,  1.4       ],
           [ 0.00912908,  0.00912908,  0.05,  0.05, 0.05      ]])

low, high = array([ 0.633,  0.01 ]), array([ 1.325,  0.99 ])

l0r = searchsorted(x[0,:], low[0], side='right')
l1r = searchsorted(x[1,:], low[1], side='right')
lr = max(l0r, l1r)

h0l = searchsorted(x[0,lr:], high[0], side='left')
h1l = searchsorted(x[1,lr:], high[1], side='left')

hl = min(h0l, h1l) + lr

t1 = x[:,range(lr)]
xn = x[:,range(lr,hl)]
ncol = shape(x)[1]
t2 = x[:,range(hl,ncol)]

print x
del(x)
print
print t1
print
# note that xn is a void array 
print xn
print
print t2

# [[ 0.46006547  1.4         1.4         1.4         1.4       ]
#  [ 0.00912908  0.00912908  0.05        0.05        0.05      ]]
# 
# [[ 0.46006547  1.4       ]
#  [ 0.00912908  0.00912908]]
# 
# []
# 
# [[ 1.4   1.4   1.4 ]
#  [ 0.05  0.05  0.05]]