更新:有没有人知道如何在Cython中编写下面的While循环?我需要它非常快。我对Cython没有经验,但是从我可以收集到的内容中,我需要用可以在C中编译的代码替换NumPy调用。
有没有办法让下面的代码更快,例如,通过替换NumPy函数,以便Numba可以在nopython模式下编译while循环?输入变量(行,列)是1D列表/整数数组。
.so
可以对以下数据集进行基准测试:
import numpy as np
def cluster(rows, cols):
member_sets = []
centroids = []
while cols.any():
un, coun = np.unique(cols, return_counts=True)
centroid = un[np.argmax(coun)]
centroids.append(centroid)
member_set = rows[cols == centroid]
member_sets.append(member_set)
rows = rows[np.in1d(cols, member_set, invert=True)]
cols = cols[np.in1d(cols, member_set, invert=True)]
cols = cols[np.in1d(rows, member_set, invert=True)]
rows = rows[np.in1d(rows, member_set, invert=True)]
return member_sets, centroids
我在笔记本电脑上得到了什么:
from sklearn.neighbors import NearestNeighbors
from sklearn.datasets import make_blobs
X, y = make_blobs(1000, random_state=1)
nbrs = NearestNeighbors(metric='euclidean', radius=3)
nbrs.fit(X)
adj_mat = nbrs.radius_neighbors_graph(X)
rows = adj_mat.nonzero()[0]
cols = adj_mat.nonzero()[1]
是否可以将速度提高一个数量级(<1秒)或更好?
以下是分析的结果,数据集大小为10,000。看起来最昂贵的NumPy电话是argsort。
import timeit
timeit.timeit("cluster(rows, cols)", globals=globals(), number=100)
Output: 12.490656665991992