有没有一种方法可以使用numpy的数组索引对列执行多次检查?

时间:2019-04-12 21:44:52

标签: python arrays numpy

我有一个2D数据数组,并且我正在尝试从该数据中有效修剪不良列。我正在尝试删除所有包含值0,最小值和最大值之间的绝对差大于12或值大于9.5的列。

我拥有的代码可以运行,但是速度很慢。据我了解,在后台,对于这些代码行中的每行,我的数组都有一个循环。我想知道是否有一种方法可以将其减少到一个循环。

import numpy as np

data_array = data_array[:,abs(data_array).min(0)!=0]
data_array = data_array[:,abs(data_array.min(0)-data_array.max(0)) < 12]
data_array = data_array[:,abs(data_array).max(0) < 9.5]

1 个答案:

答案 0 :(得分:0)

我认为不可能在一个循环中执行这三个检查。

通过适当地排序修剪操作,您可能会提高性能。确实,您应该检查首先删除最多列的条件,以便传递给第二个过滤器的数组尽可能小。其余过滤器也适用相同的条件。

根据注释,您的数据范围从-3030。可以预期,最常见的无效列是那些包含大于9.5的值的列。我还猜测为什么丢弃列的最不频繁的原因是存在零值。如果这些假设不正确,则应相应更改过滤器的顺序。通过删除不必要的函数调用(例如abs)可以实现进一步的改进。

以下功能以如上所述的不同顺序实现相同的过滤操作:

import numpy as np

def trim(x, low=0, high=9.5, diff=12):
    x = x[:, np.all(x != 0, axis=0)]
    x = x[:, np.ptp(x, axis=0) <= diff]
    x = x[:, np.all(x <= high, axis=0)]
    return x

def trim_reordered(x, low=0, high=9.5, diff=12):
    x = x[:, np.all(x <= high, axis=0)]
    x = x[:, np.ptp(x, axis=0) <= diff]
    x = x[:, np.all(x != 0, axis=0)]
    return x

演示

In [205]: np.random.seed(213)

In [206]: small_arr = np.random.randint(low=-30, high=30, size=(3, 10))

In [207]: small_arr
Out[207]: 
array([[ 13,   6,   2, -29,  13,  11, -12, -24,   5,   9],
       [ 29,  24,  16, -21, -27,  -5,  -5, -16,  21, -29],
       [-10,  10, -24, -10,   4,   0,  -8, -23,   0,   4]])

In [208]: trim(small_arr)
Out[208]: 
array([[-12, -24],
       [ -5, -16],
       [ -8, -23]])

In [209]: large_arr = np.random.randint(low=-30, high=30, size=(10, 10**6))

In [210]: %timeit trim(large_arr)
77.3 ms ± 470 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [211]: %timeit trim_reordered(large_arr)
16.1 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [212]: np.all(trim(large_arr) == trim_reordered(large_arr))
Out[212]: True