用Cython为NearestNeighbors编写的自定义指标仍然很慢

时间:2020-11-04 04:21:44

标签: python numpy cython

如本post中所述,使用纯Python编写的自定义指标会导致KNN运行缓慢。我推出了自己的用Cython编写的自定义指标,但是它仍然很慢。

这是我用Cython distances.pyx编写的自定义指标。这非常简单,因为它可以计算前三个维度的欧几里得距离,而其余维度计算的弦距离。 srcdst的形状均为12x1

import numpy
cimport numpy
from libc.math cimport sqrt


def euclidean_chordal(numpy.ndarray[numpy.float_t, ndim=1] src, dst):
    cdef int i
    cdef int dims = src.shape[0]
    cdef double diff
    cdef double euclidean_loss = 0.0
    cdef double chordal_loss = 0.0
    # Euclidean distance for point clouds and Chordal distance for frames
    for i in range(dims):
        diff = src[i] - dst[i]
        if i < 3:
            euclidean_loss += diff * diff
        else:
            chordal_loss += diff * diff
    euclidean_loss = sqrt(euclidean_loss)
    chordal_loss = sqrt(chordal_loss)

    return euclidean_loss + chordal_loss

这是setup.py

# Run python setup.py build_ext --inplace

from distutils.core import setup
from Cython.Build import cythonize

import numpy


setup(ext_modules=cythonize("distances.pyx"), include_dirs=[numpy.get_include()])

然后我用euclidean_chordal测试了sklearn.neighbors.NearestNeighbors。它仍然非常慢。有人知道这里发生了什么吗?

from distance.distances import euclidean_chordal
from sklearn.neighbors import NearestNeighbors

# src here is a point cloud with a shape of 76800 x 12
# dst here is a point cloud with a shape of 76800 x 12
neigh = NearestNeighbors(n_neighbors=1, metric=lambda a, b: euclidean_chordal(a, b))
neigh.fit(dst)
distances, indices = neigh.kneighbors(src, return_distance=True)

0 个答案:

没有答案