Strange numba behavior when assigning to an array

Posted: 2018-04-24 19:02:26

Tags: numpy numba

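I have a function that I'm compiling with @jit(nopython=True).

Inside it there is a loop that does a bunch of calculations, computes a correlation, and then assigns it to a preallocated output array. The target array and the correlation have the same type (np.float32), but for some reason the assignment makes the function run roughly 50 times slower (10.5 s vs. 220 ms in the timings below).

To make things even stranger, if I instead assign a meaningless float, np.float32(i * 1.01), rather than my correlation value, the function runs at a reasonable speed.

Given that everything is the same type, shouldn't both versions run at the same speed?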

corrs = np.zeros(a.shape[0], dtype=np.float32)

for i in range(lb, a.shape[0]):
    #a bunch of calculations happens here

    correl = np.float32(covar/(a_std*b_std))

    testval = np.float32(i*1.01)

    #doing this makes the function take FOREVER
    #corrs[i] = correl

    #but doing this runs very quickly, even though it is also a np.float32
    #corrs[i] = testval
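A reduced sketch of the pattern, under an assumption that the post does not confirm: the difference may come from dead-code elimination. numba compiles nopython functions through LLVM, and when the test value is stored instead of correl, correl never reaches the output, so the optimizer may be free to drop the O(lb) covariance loop that feeds it. The names inner_work and fill below are hypothetical stand-ins, and numba is treated as optional so the sketch also runs as plain Python (where no such elimination happens).

```python
import numpy as np

try:
    from numba import njit  # compile with numba when it is installed
except ImportError:  # otherwise run the same code as plain Python
    def njit(func):
        return func


@njit
def inner_work(a, i, lb):
    # O(lb) loop standing in for the covariance calculation
    s = 0.0
    for j in range(i - lb + 1, i + 1):
        s += a[j] * a[j]
    return s


@njit
def fill(a, lb, use_result):
    out = np.zeros(a.shape[0], dtype=np.float32)
    for i in range(lb, a.shape[0]):
        work = inner_work(a, i, lb)
        if use_result:
            out[i] = np.float32(work)  # result of the O(lb) loop is stored
        else:
            # 'work' is never read on this path, so under numba the
            # optimizer may skip computing it altogether
            out[i] = np.float32(i * 1.01)
    return out
```

With numba installed, timing fill(a, 1000, True) against fill(a, 1000, False) on a large float64 array would show whether a similar gap appears. If it does, a fairer benchmark makes the fast path consume work as well, for example by adding it to a running total that the function returns.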

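Here is a runnable example. I added a parameter called "assign": if True, the function assigns the value I actually want (the correlation); if False, it assigns my useless test value.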

import numpy as np
from numba import jit


@jit(nopython=True)
def hist_corr_loop(a, b, lb=1000, assign=True):
    flb = np.float32(lb)

    # initial means over the first window a[0:lb], b[0:lb]
    a_mu, b_mu = a[0], b[0]
    for i in range(1, lb):
        a_mu += a[i]
        b_mu += b[i]
    a_mu = a_mu / flb
    b_mu = b_mu / flb

    # initial (population) variances over the first window
    a_var, b_var = np.float32(0.0), np.float32(0.0)
    for i in range(lb):
        a_var += np.square(a[i] - a_mu)
        b_var += np.square(b[i] - b_mu)
    a_var = a_var / flb
    b_var = b_var / flb

    corrs = np.zeros(a.shape[0], dtype=np.float32)

    for i in range(lb, a.shape[0]):
        # calculate new means and stdevs: slide the window by one,
        # dropping a[i-lb] and adding a[i]
        _a_mu = a_mu
        _b_mu = b_mu

        a_mu = _a_mu + (a[i] - a[i-lb]) / flb
        b_mu = _b_mu + (b[i] - b[i-lb]) / flb

        a_var += (a[i] - a[i-lb]) * (a[i] - a_mu + a[i-lb] - _a_mu) / flb
        b_var += (b[i] - b[i-lb]) * (b[i] - b_mu + b[i-lb] - _b_mu) / flb

        a_std = np.sqrt(a_var)  # **0.5
        b_std = np.sqrt(b_var)  # **0.5

        # covariance is recomputed from scratch over a[i-lb+1:i+1]: O(lb) per step
        covar = np.float32(0.0)
        for j in range(i-lb+1, i+1):
            covar += (a[j] - a_mu) * (b[j] - b_mu)
        covar = covar / flb

        correl = np.float32(covar / (a_std * b_std))

        testval = np.float32(i * 1.01)

        if assign:
            corrs[i] = correl
        else:
            corrs[i] = testval

    return corrs
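The O(1) mean and variance updates in the loop can be checked with a short derivation. Here n stands for lb, x_new for a[i], x_old for a[i-lb], mu for the old mean _a_mu, mu' for the new mean a_mu, and sigma^2 for the population variance a_var, using sigma^2 = (1/n) * sum(x^2) - mu^2. Sliding the window changes sum(x^2) by x_new^2 - x_old^2, and mu'^2 - mu^2 = (mu' - mu)(mu' + mu):

$$\mu' = \mu + \frac{x_{\text{new}} - x_{\text{old}}}{n}$$

$$\sigma'^2 - \sigma^2 = \frac{x_{\text{new}}^2 - x_{\text{old}}^2}{n} - \left(\mu'^2 - \mu^2\right) = \frac{(x_{\text{new}} - x_{\text{old}})(x_{\text{new}} + x_{\text{old}})}{n} - \frac{(x_{\text{new}} - x_{\text{old}})(\mu' + \mu)}{n}$$

$$\sigma'^2 = \sigma^2 + \frac{(x_{\text{new}} - x_{\text{old}})\,(x_{\text{new}} - \mu' + x_{\text{old}} - \mu)}{n}$$

The last line is exactly the a_var and b_var update in the code. Only the covariance is still computed by the inner j loop, at O(lb) per outer step.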

Running:

n = 10000000
a = np.random.random(n)
b = np.random.random(n)

%timeit hist_corr_loop(a,b,1000, True)
%timeit hist_corr_loop(a,b, 1000, False)

I get:

%timeit hist_corr_loop(a,b,1000, True)
10.5 s ± 52.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit hist_corr_loop(a,b, 1000, False)
220 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
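A rough operation count puts the two timings on a per-iteration basis. With n = 10,000,000 and lb = 1000, the outer loop runs n - lb times, and when the covariance is actually computed the inner j loop adds lb multiply-adds per outer step. The numbers below are just divisions of the measured times; they are consistent with, but do not prove, the hypothesis that the assign=False run skips the inner loop.

```python
# Per-iteration cost implied by the two measured timings (seconds).
n, lb = 10_000_000, 1000
t_assign_true, t_assign_false = 10.5, 0.220

outer = n - lb            # iterations of the outer loop
inner = outer * lb        # total iterations of the inner covariance loop

ns_per_inner_true = t_assign_true / inner * 1e9     # if the inner loop runs
ns_per_inner_false = t_assign_false / inner * 1e9   # if it still ran with assign=False
ns_per_outer_false = t_assign_false / outer * 1e9   # if it was skipped
ratio = t_assign_true / t_assign_false

print(f"{outer=:,} {inner=:,}")
print(f"assign=True : {ns_per_inner_true:.2f} ns per inner iteration")
print(f"assign=False: {ns_per_inner_false:.3f} ns per inner iteration, "
      f"or {ns_per_outer_false:.1f} ns per outer iteration")
print(f"slowdown: {ratio:.1f}x")
```

About 1 ns per inner multiply-add is plausible for a compiled sequential floating-point sum, while fitting all ~10^10 inner iterations into 220 ms would need about 0.022 ns each, which seems implausible for such a sum. About 22 ns per outer iteration of O(1) work is plausible.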

10 s vs. 220 ms.
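Separately from the timing question, the values returned with assign=True can be cross-checked against a plain NumPy version. The sketch below uses a different technique from the loop above (prefix sums rather than an explicit covariance loop), so it is an independent check and not the poster's code. It computes the same sliding-window Pearson correlation over a[i-lb+1:i+1] with population statistics and leaves zeros before index lb, in O(n) time. The function name rolling_corr_cumsum is hypothetical.

```python
import numpy as np


def rolling_corr_cumsum(a, b, lb):
    """Pearson correlation of a[i-lb+1:i+1] and b[i-lb+1:i+1] for i >= lb.

    Output layout mirrors hist_corr_loop: entries before index lb stay 0.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    n = a.shape[0]

    def window_sums(x):
        # prefix sums with a leading 0, so s[k] == x[:k].sum()
        s = np.concatenate(([0.0], np.cumsum(x)))
        # sum of x[i-lb+1:i+1] for each i = lb .. n-1
        return s[lb + 1:] - s[1:n - lb + 1]

    sa, sb = window_sums(a), window_sums(b)
    saa, sbb, sab = window_sums(a * a), window_sums(b * b), window_sums(a * b)

    mean_a, mean_b = sa / lb, sb / lb
    cov = sab / lb - mean_a * mean_b
    var_a = saa / lb - mean_a ** 2
    var_b = sbb / lb - mean_b ** 2

    out = np.zeros(n, dtype=np.float32)
    out[lb:] = cov / np.sqrt(var_a * var_b)
    return out
```

np.allclose(rolling_corr_cumsum(a, b, 1000), hist_corr_loop(a, b, 1000, True), atol=1e-5) is one way to compare the two outputs. The E[xy] - E[x]E[y] form can lose precision through cancellation on long or badly scaled inputs, so this is meant as a check rather than a replacement.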

0 answers