来自sklearn的GMM意外表现不佳

时间:2014-12-11 21:49:51

标签: python scikit-learn classification cluster-analysis

我尝试使用scikitlearn的DPGMM分类器来模拟一些模拟数据,但我的性能很差。以下是我使用的示例:

from sklearn import mixture
import numpy as np
import matplotlib.pyplot as plt
clf = mixture.DPGMM(n_components=5, init_params='wc')
s = 0.1
a = np.random.normal(loc=1, scale=s, size=(1000,))
b = np.random.normal(loc=2, scale=s, size=(1000,))
c = np.random.normal(loc=3, scale=s, size=(1000,))
d = np.random.normal(loc=4, scale=s, size=(1000,))
e = np.random.normal(loc=7, scale=s*2, size=(5000,))
noise = np.random.random(500)*8 
data = np.hstack([a,b,c,d,e,noise]).reshape((-1,1))
clf.means_ = np.array([1,2,3,4,7]).reshape((-1,1))
clf.fit(data)
labels = clf.predict(data)
plt.scatter(data.T, np.random.random(len(data)), c=labels, lw=0, alpha=0.2)
plt.show()

我认为这正是高斯混合模型可以解决的问题。我尝试过玩alpha,使用gmm而不是dpgmm,改变起始组件的数量等等。我似乎无法获得可靠和准确的分类。我有什么东西不见了吗?还有另一种更合适的模式吗?

1 个答案:

答案 0 :(得分:0)

因为你没有足够长的迭代时间来收敛

检查

的值
clf.converged_

并尝试将n_iter增加到1000

但请注意,DPGMM仍然在此数据集上失败恕我直言,最终将群集数量减少到2个。