numba的高效方形欧氏距离代码是否比numpy的高效对应代码慢?

时间:2018-06-04 07:48:52

标签: python numpy numba

我从(Why this numba code is 6x slower than numpy code?)修改最有效的代码,以便它可以处理x1为(n,m)

@nb.njit(fastmath=True,parallel=True)
def euclidean_distance_square_numba_v5(x1, x2):
    res = np.empty((x1.shape[0], x2.shape[0]), dtype=x2.dtype)
    for a_idx in nb.prange(x1.shape[0]):
        for o_idx in range(x2.shape[0]):
            val = 0.
            for i_idx in range(x2.shape[1]):
                tmp = x1[a_idx, i_idx] - x2[o_idx, i_idx]
                val += tmp * tmp 
            res[a_idx, o_idx] = val 
    return res

然而,更高效的numpy版本仍然没有效率:

def euclidean_distance_square_einsum(x1, x2):
    return np.einsum('ij,ij->i', x1, x1)[:, np.newaxis] + np.einsum('ij,ij->i', x2, x2) - 2*np.dot(x1, x2.T)

输入为

a = np.zeros((1000000,512), dtype=np.float32)
b = np.zeros((100, 512), dtype=np.float32)

我得到的时间是numba代码为2.4723422527313232,numpy代码为0.8260958194732666。

1 个答案:

答案 0 :(得分:2)

是的,这是预期的。

你必须要注意的第一件事是:dot-product是numpy-version的工作马,这里是稍微小一点的数组:

>>> def only_dot(x1, x2):
        return - 2*np.dot(x1, x2.T)

>>> a = np.zeros((1000,512), dtype=np.float32)
>>> b = np.zeros((100, 512), dtype=np.float32)

>>> %timeit(euclidean_distance_square_einsum(a,b))
6.08 ms ± 312 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit(euclidean_only_dot(a,b))
5.25 ms ± 330 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

即。 85%的时间都花在了它上面。

当你看到你的numba代码时,看起来像一个有点奇怪/不寻常/更复杂的矩阵 - 矩阵乘法版本 - 人们可以看到例如相同的三个循环。

所以基本上,你试图击败其中一个最优化的算法。这是例如somebody trying to do it and failing。我的安装使用的是英特尔的MKL版本,它必须比默认实现更复杂,可以找到here

有时候,在完全有趣之后,人们不得不承认自己的重新发明的车轮"并不像最先进的车轮那么好......但只有这样才能真正体会它的表现。