如何加快numpy张量*张量操作

时间:2020-01-24 16:58:33

标签: python numpy optimization numba

我的代码中有一个瓶颈,那就是numpy 3d数组乘以* operator与numpy 3d数组。
我想使用numba @njit或@jit装饰器来加速程序的这一部分,但是它使性能降低了2倍。
代码的慢部分:

@numba.jit
def mat_mul_and_sum(img1, img2, alpha):
    return img1*(1-alpha) + img2*alpha 

img1,img2和alpha是具有相同形状的3d np.array。
可以加快这一行代码的速度吗?

2 个答案:

答案 0 :(得分:4)

一个选项实际上是以应该使用的方式使用numba(而不仅仅是应用装饰器)。但是,对于您的特定功能,您可以使用numexpr软件包使用多核渲染。


import numexpr as ne

def mat_mul_and_sum_numexpr(a, b, alpha):
    return ne.evaluate('a*(1-alpha) + b*alpha')

使用其他答案中的时间:

In [11]: %timeit mat_mul_and_sum(img1, img2, alpha)
21.6 ms ± 955 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [12]: %timeit mat_mul_and_sum2(img1, img2, alpha)
6.35 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [13]: %timeit mat_mul_and_sum_numexpr(img1, img2, alpha)
4.22 ms ± 54.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [14]: np.allclose(mat_mul_and_sum(img1, img2, alpha), mat_mul_and_sum_numexpr(img1, img2, alpha))
Out[14]: True

通过numba的并行化,您也许可以挤出一些额外的性能,但是通常使用numexpr可以大大提高性能,而无需重写任何代码。

答案 1 :(得分:3)

如果按以下方式展开循环,则对于数组大小为(100,100,100)的numba,其速度是纯numpy版本的两倍,这可能是由于不需要分配中间数组的事实: / p>

import numpy as np
import numba as nb

def mat_mul_and_sum(img1, img2, alpha):
    return img1*(1-alpha) + img2*alpha


@nb.jit
def mat_mul_and_sum2(img1, img2, alpha):
    NI, NJ, NK = img1.shape
    out = np.empty((NI, NJ, NK))

    for i in range(NI):
        for j in range(NJ):
            for k in range(NK):
                out[i,j,k] = img1[i,j,k] * (1.0 - alpha[i,j,k]) + img2[i,j,k] * alpha[i,j,k]

    return out

然后进行测试:

N = 100
img1 = np.random.normal(size=(N, N, N))
img2 = np.random.normal(size=(N, N, N))
alpha = np.random.normal(size=(N, N, N))

A = mat_mul_and_sum(img1, img2, alpha)
B = mat_mul_and_sum2(img1, img2, alpha)

np.allclose(A,B) #True

%timeit mat_mul_and_sum(img1, img2, alpha)
# 4.6 ms ± 44.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit mat_mul_and_sum2(img1, img2, alpha)
# 2.4 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

更新: 您也可以尝试将装饰器更改为nb.jit(parallel=True),然后将外部循环替换为for i in nb.prange(NI):,这在我的机器上将结果从timeit降低到1.37毫秒。每个机器的时间以及其他输入时间肯定会有所不同。