我是numba的用户,有人可以告诉我为什么numpy数组这么慢,这是一个例子:
def pairwise_python2(X):
n_samples = X.shape[0]
result = np.zeros((n_samples, n_samples), dtype=X.dtype)
for i in xrange(X.shape[0]):
for j in xrange(X.shape[0]):
result[i, j] = np.sqrt(np.sum((X[i, :] - X[j, :]) ** 2))
return result
%timeit pairwise_python2(X)
1个循环,最佳3:每循环18.2秒
from numba import double
from numba.decorators import jit, autojit
pairwise_numba = autojit(pairwise_python)
%timeit pairwise_numba(X)
1个循环,最佳3:每循环13.9秒
似乎jit和cpython版本之间没有区别,我错了吗?
答案 0 :(得分:1)
你正在计算numpy内存分配。 X [i,:] - X [j,:]生成一个新的形状矩阵(n_samples,n_samples),方形运算也是如此。请尝试以下内容:
def pairwise_python2(X):
n_samples = X.shape[0]
result = np.empty((n_samples, n_samples), dtype=X.dtype)
temp = np.empty((n_samples,), dtype=X.dtype)
for i in xrange(n_samples):
slice = X[i,:]
for j in xrange(n_samples):
result[i,j] = np.sqrt(np.sum(np.power(np.subtract(slice,X[j,:],temp),2.0,temp)))
return result
Numba并没有添加很多内容,因为你在numpy中进行所有操作(虽然它会加速循环迭代,这可以在你的计时函数中看到)。
答案 1 :(得分:1)
新版numba
支持numpy
数组切片和 np.sqrt()
功能。所以,这个问题可以关闭。