从sklearn
使用自定义距离指标函数进行聚类算法时,我遇到了性能瓶颈。
Run Snake Run显示的结果如下:
显然问题是dbscan_metric
功能。该功能看起来非常简单,我不太清楚加速它的最佳方法是:
def dbscan_metric(a,b):
if a.shape[0] != NUM_FEATURES:
return np.linalg.norm(a-b)
else:
return np.linalg.norm(np.multiply(FTR_WEIGHTS, (a-b)))
任何关于导致它如此缓慢的想法都会非常感激。
答案 0 :(得分:1)
我不熟悉函数的作用 - 但是有可能重复计算吗?如果是这样,你可以记住这个功能:
cache = {}
def dbscan_metric(a,b):
diff = a - b
if a.shape[0] != NUM_FEATURES:
to_calc = diff
else:
to_calc = np.multiply(FTR_WEIGHTS, diff)
if not cache.get(to_calc): cache[to_calc] = np.linalg.norm(to_calc)
return cache[to_calc]