numba输出的差异

时间:2016-06-28 10:59:10

标签: python numpy numba

我在学习工作中实施了基本的近邻搜索。 事实是,基本的numpy实现运行良好,但只是添加' @ jit'装饰器(在Numba中编译),输出是不同的(它由于一些未知的原因它最终复制了一些邻居......)

这是基本算法:

import numpy as np
from numba import jit

@jit(nopython=True)
def knn(p, points, k):
    '''Find the k nearest neighbors (brute force) of the point p
    in the list points (each row is a point)'''

    n = p.size  # Lenght of the points
    M = points.shape[0]  # Number of points
    neighbors = np.zeros((k,n))
    distances = 1e6*np.ones(k)

    for i in xrange(M):
        d = 0
        pt = points[i, :]  # Point to compare
        for r in xrange(n):  # For each coordinate
            aux = p[r] - pt[r]
            d += aux * aux
        if d < distances[k-1]:  # We find a new neighbor
            pos = k-1
            while pos>0 and d<distances[pos-1]:  # Find the position
                pos -= 1
            pt = points[i, :]
            # Insert neighbor and distance:
            neighbors[pos+1:, :] = neighbors[pos:-1, :]
            neighbors[pos, :] = pt
            distances[pos+1:] = distances[pos:-1]
            distances[pos] = d

    return neighbors, distances

进行测试:

p = np.random.rand(10)
points = np.random.rand(250, 10)
k = 5
neighbors = knn(p, points, k)

没有@jit装饰器,就可以得到正确答案:

In [1]: distances
Out[1]: array([ 0.3933974 ,  0.44754336,  0.54548715,  0.55619749,  0.5657846 ])

但Numba汇编给出了奇怪的输出:

Out[2]: distances
Out[2]: array([ 0.3933974 ,  0.44754336,  0.54548715,  0.54548715,  0.54548715])

有人可以帮忙吗?我不知道为什么会这样......

谢谢你。

1 个答案:

答案 0 :(得分:1)

我认为问题在于,当那些切片重叠而不是没有切片时,Numba正在处理将一个切片写入另一个切片。我不熟悉numpy的内部结构,但也许有一些特殊的逻辑可以处理像这样的易失性内存操作,这些内容在Numba中并不存在。更改以下行,jit装饰器的结果与普通的python版本保持一致:

neighbors[pos+1:, :] = neighbors[pos:-1, :].copy()
...
distances[pos+1:] = distances[pos:-1].copy()