numba numpy数组切片太慢了?

时间:2013-11-12 01:14:54

标签: numba

我是numba的用户,有人可以告诉我为什么numpy数组这么慢,这是一个例子:

def pairwise_python2(X):

    n_samples = X.shape[0]

    result = np.zeros((n_samples, n_samples), dtype=X.dtype)

    for i in xrange(X.shape[0]):

        for j in xrange(X.shape[0]):

            result[i, j] = np.sqrt(np.sum((X[i, :] - X[j, :]) ** 2))

    return result

%timeit pairwise_python2(X)

1个循环,最佳3:每循环18.2秒

from numba import double

from numba.decorators import jit, autojit

pairwise_numba = autojit(pairwise_python)

%timeit pairwise_numba(X)

1个循环,最佳3:每循环13.9秒

似乎jit和cpython版本之间没有区别,我错了吗?

2 个答案:

答案 0 :(得分:1)

你正在计算numpy内存分配。 X [i,:] - X [j,:]生成一个新的形状矩阵(n_samples,n_samples),方形运算也是如此。请尝试以下内容:

def pairwise_python2(X):
    n_samples = X.shape[0]
    result = np.empty((n_samples, n_samples), dtype=X.dtype)
    temp = np.empty((n_samples,), dtype=X.dtype)
    for i in xrange(n_samples):
        slice = X[i,:]
        for j in xrange(n_samples):
            result[i,j] = np.sqrt(np.sum(np.power(np.subtract(slice,X[j,:],temp),2.0,temp)))
    return result

Numba并没有添加很多内容,因为你在numpy中进行所有操作(虽然它会加速循环迭代,这可以在你的计时函数中看到)。

答案 1 :(得分:1)

新版numba支持numpy数组切片和 np.sqrt() 功能。所以,这个问题可以关闭。