我正在MeanShift()
模块(here are the docs)中运行名为sklearn.cluster
的群集算法。我正在处理的对象在三维空间中分布有310,057个点。我正在运行它的计算机总共有128Gb的ram,所以当我得到以下错误时,我很难相信我实际上正在使用它。
[user@host ~]$ python meanshifttest.py
Traceback (most recent call last):
File "meanshifttest.py", line 13, in <module>
ms = MeanShift().fit(X)
File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 280, in fit
cluster_all=self.cluster_all)
File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 99, in mean_shift
bandwidth = estimate_bandwidth(X)
File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 45, in estimate_bandwidth
d, _ = nbrs.kneighbors(X, return_distance=True)
File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/neighbors/base.py", line 313, in kneighbors
return_distance=return_distance)
File "binary_tree.pxi", line 1313, in sklearn.neighbors.kd_tree.BinaryTree.query (sklearn/neighbors/kd_tree.c:10007)
File "binary_tree.pxi", line 595, in sklearn.neighbors.kd_tree.NeighborsHeap.__init__ (sklearn/neighbors/kd_tree.c:4709)
MemoryError
我正在运行的代码如下所示:
from sklearn.cluster import MeanShift
import asciitable
import numpy as np
import time
data = asciitable.read('./multidark_MDR1_FOFID85000000000_ParticlePos.csv',delimiter=',')
x = [data[i][2] for i in range(len(data))]
y = [data[i][3] for i in range(len(data))]
z = [data[i][4] for i in range(len(data))]
X = np.array(zip(x,y,z))
t0 = time.time()
ms = MeanShift().fit(X)
t1 = time.time()
print str(t1-t0) + " seconds."
labels = ms.labels_
print set(labels)
有人会对发生的事情有任何想法吗?不幸的是我不能切换聚类算法,因为这是我发现的唯一一个除了不接受链接长度/ k个簇/先验信息之外还做得很好的算法。
提前致谢!
**更新: 我更多地查看了文档,并说明了以下内容:
可扩展性:
因为此实现使用扁平内核和
一个Ball Tree查找每个内核的成员,复杂性将是
较低维度的O(T * n * log(n)),n个样本数量 和T点数。在更高的维度上,复杂性将是 倾向于O(T * n ^ 2)。通过使用更少的种子可以提高可伸缩性,例如通过使用
get_bin_seeds函数中较高的min_bin_freq值。请注意,estimate_bandwidth函数的可扩展性远低于
平均移位算法,如果使用它将成为瓶颈。
这似乎有道理,因为如果你仔细看一下错误就会抱怨estimate_bandwidth。这是否表示我只是在为算法使用太多粒子?
答案 0 :(得分:4)
从错误信息判断,我怀疑它正在尝试计算点之间的所有成对距离,这意味着它需要310057²浮点数或716GB RAM。
您可以通过向bandwidth
构造函数提供显式MeanShift
参数来禁用此行为。
这可以说是一个错误;考虑为它提交错误报告。 (包括我自己在内的scikit-learn工作人员最近一直在努力摆脱各种地方过于昂贵的距离计算,但显然没有人看过手段。)
编辑:上面的计算是3倍,但内存使用量确实是二次的。我只是修改了scikit-learn的开发版本。