Optimizing a Numba and NumPy function

Time: 2019-06-14 09:35:30

Tags: python optimization numba

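I'm trying to make this code run faster, but I can't find any more tricks to speed it up.

Each call takes about 3 microseconds. The problem is that I call this function millions of times, so the whole process ends up taking a very long time. I have the same implementation in Java (using only basic for loops), and there the computation is essentially instantaneous, even for large training data (this is for an ANN).

Is there any way to speed this up?

I'm running Python 2.7, numba 0.43.1 and numpy 1.16.3 on Windows 10.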

import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1, 0.1, 0.1])
positive_weight = np.array([0.2, 0.2, 0.2])
total_sq_grad_positive = np.array([0.1, 0.1, 0.1])
learning_rate = 1

@nb.njit(fastmath=True, cache=True)
def update_weight_from_post_post_jit(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    if x:
        g = np.multiply(eligibility, (1 - expected))
    else:
        g = np.negative(np.multiply(eligibility, expected))
    gg = np.multiply(g, g)
    total_sq_grad_positive = np.add(total_sq_grad_positive, gg)
    #total_sq_grad_positive = np.where(divide_by_zero,total_sq_grad_positive, tsgp_temp)

    temp = np.multiply(learning_rate, g)
    temp2 = np.sqrt(total_sq_grad_positive)
    #temp2 = np.where(temp2 == 0,1,temp2 )
    temp2[temp2 == 0] = 1   # avoid division by zero
    temp = np.divide(temp, temp2)
    positive_weight = np.add(positive_weight, temp)
    return [positive_weight, total_sq_grad_positive]
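For reference, this is what the function computes for each element i, writing e for eligibility, w for positive_weight, G for total_sq_grad_positive and η for learning_rate (this shorthand is mine, not names from the code). It appears to be an AdaGrad-style step, with a zero check on the denominator in place of the usual small ε:

```latex
\begin{aligned}
g_i &= \begin{cases} e_i\,(1 - \mathrm{expected}) & \text{if } x \\ -\,e_i\,\mathrm{expected} & \text{otherwise} \end{cases} \\
G_i &\leftarrow G_i + g_i^2 \\
w_i &\leftarrow w_i + \frac{\eta\, g_i}{d_i}, \qquad
d_i = \begin{cases} \sqrt{G_i} & \text{if } G_i \neq 0 \\ 1 & \text{otherwise} \end{cases}
\end{aligned}
```

The last line uses the already-updated G_i, so each of the three elements costs a few multiplications, one square root and one division.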

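1 answer:

Answer 0 (score: 1)

Edit: @max9111 appears to be right: the unnecessary temporary arrays are the source of the overhead.

With the function's current semantics, two temporary arrays seem unavoidable: the return values [positive_weight, total_sq_grad_positive]. However, it occurs to me that you may actually intend this function to update those two input arrays. If so, the biggest speedup comes from doing everything in place, like this: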

import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1,0.1,0.1])
positive_weight = np.array([0.2,0.2,0.2])
total_sq_grad_positive = np.array([0.1,0.1,0.1])
learning_rate = 1

@nb.njit(fastmath=True, cache=True)
def update_weight_from_post_post_jit(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    # Scalar loop: no temporary arrays are allocated, and both
    # positive_weight and total_sq_grad_positive are updated in place.
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1 - expected)
        else:
            g = -(eligibility[i] * expected)
        gg = g * g
        total_sq_grad_positive[i] = total_sq_grad_positive[i] + gg

        temp = learning_rate * g
        temp2 = np.sqrt(total_sq_grad_positive[i])
        if temp2 == 0: temp2 = 1   # avoid division by zero
        temp = temp / temp2
        positive_weight[i] = positive_weight[i] + temp

# Runs the update n times from compiled code rather than from Python
@nb.jit
def test(n, *args):
    for i in range(n): update_weight_from_post_post_jit(*args)

If you don't want the input arrays to be modified, you can use

positive_weight = positive_weight.copy()
total_sq_grad_positive = total_sq_grad_positive.copy()

and return them as in the original code. This is not as fast, but it is still faster than the original.
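A minimal sketch of that copy-then-return variant follows. The loop is written as ordinary Python here, with the @nb.njit decorator left off so the sketch runs without Numba; putting the decorator back is what makes it fast. The names update_inplace and update_weight_copy are hypothetical, not from the original code.

```python
import math

import numpy as np


def update_inplace(x, expected, eligibility, positive_weight,
                   total_sq_grad_positive, learning_rate):
    # Same loop as the in-place answer code, left undecorated so this sketch
    # runs without Numba; add @nb.njit(...) back for speed.
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1 - expected)
        else:
            g = -(eligibility[i] * expected)
        total_sq_grad_positive[i] += g * g
        denom = math.sqrt(total_sq_grad_positive[i])
        if denom == 0:
            denom = 1.0  # avoid dividing by zero, as in the original
        positive_weight[i] += learning_rate * g / denom


def update_weight_copy(x, expected, eligibility, positive_weight,
                       total_sq_grad_positive, learning_rate):
    # Copy first so the caller's arrays stay untouched, update the copies
    # in place, then return them the way the original function did.
    positive_weight = positive_weight.copy()
    total_sq_grad_positive = total_sq_grad_positive.copy()
    update_inplace(x, expected, eligibility, positive_weight,
                   total_sq_grad_positive, learning_rate)
    return [positive_weight, total_sq_grad_positive]
```

With the sample inputs from the question, one call leaves the original arrays unchanged and returns total_sq_grad_positive = 0.1025 and positive_weight = 0.2 + 0.05 / sqrt(0.1025) ≈ 0.3562 for each element.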


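I'm not sure this can be optimized to be "instantaneous". I'm a bit surprised that Java manages it, since this looks like a fairly complex function to me, with time-consuming operations such as sqrt.

However, are you applying nb.jit to the function that calls this function? Like this: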

@nb.jit
def test(n):
    for i in range(n): update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate)

On my machine this cuts the runtime in half, which makes sense, because Python function calls carry a lot of overhead.
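The same idea can be pushed one step further when the inputs change on every call, as they presumably do during training: pass all the samples in a single call so the per-sample loop never goes back to Python. This sketch assumes a hypothetical batch layout (one row of eligibility traces per sample, plus one flag and one target per sample) and hypothetical names train_pass and update_inplace. It also falls back to plain Python when Numba is not installed, purely so the sketch still runs. I haven't timed it, so treat any speedup as something to measure.

```python
import math

import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Assumption: without Numba, run the same code as plain Python
    # so the sketch still executes (just slowly).
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


@njit
def update_inplace(x, expected, eligibility, positive_weight,
                   total_sq_grad_positive, learning_rate):
    # Same per-element update as the in-place answer code.
    for i in range(eligibility.shape[0]):
        g = eligibility[i] * (1.0 - expected) if x else -(eligibility[i] * expected)
        total_sq_grad_positive[i] += g * g
        denom = math.sqrt(total_sq_grad_positive[i])
        if denom == 0.0:
            denom = 1.0  # avoid division by zero
        positive_weight[i] += learning_rate * g / denom


@njit
def train_pass(xs, expecteds, eligibilities, positive_weight,
               total_sq_grad_positive, learning_rate):
    # Hypothetical batch layout: row k of `eligibilities` holds the traces
    # for sample k; xs[k] and expecteds[k] are that sample's flag and target.
    # The whole loop runs in compiled code when Numba is available.
    for k in range(xs.shape[0]):
        update_inplace(xs[k], expecteds[k], eligibilities[k],
                       positive_weight, total_sq_grad_positive, learning_rate)
```

Usage would look like train_pass(xs, expecteds, eligibilities, positive_weight, total_sq_grad_positive, 1.0), after which both weight arrays hold the result of applying every sample in order, exactly as if the update had been called once per sample.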