根据条件在numpy数组的元素中进行数学运算的有效方法

时间:2015-05-29 14:38:33

标签: python numpy optimization

我正在尝试优化我的python代码。其中一个瓶颈就出现了 我试图根据每个元素值将函数应用于numpy数组。例如,我有一个包含数千个元素的数组,我为一个大于容差的值应用了一个函数,为其余的函数应用了另一个函数(泰勒系列)。我做掩蔽但仍然很慢,至少我调用了6400万次以下的功能。

EPSILONZETA = 1.0e-6
ZETA1_12 = 1.0/12.0
ZETA1_720 = 1.0/720.0

def masked_condition_zero(array, tolerance):
    """ Return the indices where values are lesser (and greater) than tolerance
    """
    # search indices where array values < tolerance
    indzeros_ = np.where(np.abs(array) < tolerance)[0]

    # create mask
    mask_ = np.ones(np.shape(array), dtype=bool)

    mask_[[indzeros_]] = False

    return (~mask_, mask_) 

def bernoulli_function1(zeta):
    """ Returns the Bernoulli function of zeta, vector version
    """
    # get the indices according to condition
    zeros_, others_ = masked_condition_zero(zeta, EPSILONZETA)

    # create an array filled with zeros
    fb_ = np.zeros(np.shape(zeta))

    # Apply the original function to the values greater than EPSILONZETA
    fb_[others_] = zeta[others_]/(np.exp(zeta[others_])-1.0)  

    # computes series for zeta < eps
    zeta0_ = zeta[zeros_]
    zeta2_ = zeta0_ *  zeta0_
    zeta4_ =  zeta2_ * zeta2_
    fb_[zeros_] = 1.0 - 0.5*zeta0_ + ZETA1_12 * zeta2_ - ZETA1_720 * zeta4_
    return fb_

现在假设你有一个带有负浮动和正浮动的数组zeta,它在每个循环中发生变化,延伸到2 ^ 26次迭代,你想每次计算fbernoulli_function1(zeta)。

有更好的解决方案吗?

3 个答案:

答案 0 :(得分:2)

问题的基本结构是:

class CreatorStudio extends Polymer.Class({}) {

  is = "creator-studio";
}

/*(function () {
Polymer(CreatorStudio.prototype);
})();*/
document.registerElement("creator-studio", CreatorStudio.prototype);

看起来您的多项式表达式可以在所有def foo(zeta): result = np.empty_like(zeta) I = condition(zeta) nI = ~I result[I] = func1(zeta[I]) result[nI] = func2(zeta[nI]) 进行评估,但它是“例外”,即zeta太接近0时的后退计算。

如果可以为zeta评估两个函数,则可以使用where:

zeta

这是精简版:

np.where(condition(zeta), func1(zeta), func2(zeta))

另一个选择是将一个函数应用于所有值,另一个选项仅应用于“例外”。

def foo(zeta):
    result = np.empty_like(zeta)
    I = condition(zeta)
    nI = ~I
    v1 = func1(zeta)
    v2 = func2(zeta)
    result[I] = v1[I]
    result[nI] = v2[nI]

当然反过来 - def foo(zeta): result = func2(zeta) I = condition(zeta) result[I] = func1[zeta[I]]

在我的简短时间测试中,result = func1(zeta); result[nI]=func2[zeta]func1需要大约相同的时间。

func2也需要花费时间,但更简单的masked_condition_zero(以及它的np.abs(array) < tolerance)会减少一半。

让我们比较分配策略

~J

对于def foo(zeta, J, nJ): result = np.empty_like(zeta) result[J] = fun1(zeta[J]) result[nJ] = fun2(zeta[nJ]) return result 为完整zeta[J]的10%的样本,某些采样时间为:

zeta

第二种情况最快,因为在较少的值上运行In [127]: timeit foo(zeta, J, nJ) 10000 loops, best of 3: 55.7 µs per loop In [128]: timeit result=fun2(zeta); result[J]=fun1(zeta[J]) 10000 loops, best of 3: 49.2 µs per loop In [129]: timeit np.where(J, fun1(zeta),fun2(zeta)) 10000 loops, best of 3: 73.4 µs per loop In [130]: timeit result=fun1(zeta); result[nJ]=fun2(zeta[nJ]) 10000 loops, best of 3: 60.7 µs per loop 会补偿索引fun1的额外成本。在索引成本和功能评估成本之间存在权衡。像这样的布尔索引比切片更昂贵。对于其他价值观的混合,时间可以转向另一个方向。

这看起来像是一个你可以随时贬低的问题,但我看不到任何突破,虽然会使时间缩短一个数量级。

答案 1 :(得分:0)

与索引到数组相比,where命令相当慢。这可能会更快。

fb_ = np.zeros_like(zeta)
nonZero= zeta > ZETA_TOLERANCE
zero = ~nonZero
fb_[zero] = function1(zeta[zero])
fb_[nonZero] = function2(zeta[nonZero])

修改: 我意识到我的原始版本正在制作同一阵列的两个副本。这个新版本应该更快一些。

答案 2 :(得分:0)

您可以使用numba [1](如果您使用anaconda或类似的python发行版,则应安装),这是一个旨在使用numpy的jit编译器。

from numba import jit
@jit
def bernoulli_function_fill(zeta, fb_):
    for i in xrange(len(zeta)):
        if np.abs(zeta[i])>EPSILONZETA:
            fb_[i] = zeta[i]/(np.exp(zeta[i])-1.0)
        else:
            zeta0_ = zeta[i]
            zeta2_ = zeta0_ *  zeta0_
            zeta4_ =  zeta2_ * zeta2_
            fb_[i] = 1.0 - 0.5*zeta0_ + ZETA1_12 * zeta2_ - ZETA1_720 * zeta4_
def bernoulli_function_fast(zeta):
    fb_ = np.zeros_like(zeta)
    bernoulli_function_fill(zeta, fb_)
    return fb_

注意:如果您使用新版本的numba,可以将两者合并为相同的功能。

在我的机器上:

#create some test data
zeta = random.uniform(-1,1, size=2**24)
zeta[random.choice(len(zeta),size=2**23,replace=False )] = EPSILONZETA/2
>>> alltrue(bernoulli_function_fast(zeta)==bernoulli_function1(zeta))
True
>>> %timeit bernoulli_function1(zeta) # your function
1 loops, best of 3: 1.49 s per loop
>>> %timeit bernoulli_function_fast(zeta) #numba function
1 loops, best of 3: 347 ms per loop

快4倍,更容易阅读。