sklearn.mixture.DPGMM: only one cluster?

Time: 2015-05-19 19:30:00

Tags: python machine-learning scikit-learn cluster-analysis

I have a dataset on which I'm getting strange results using the Dirichlet process Gaussian mixture model in sklearn.

import numpy as np
import sklearn.mixture
from matplotlib import pyplot as plt

# Two 1-D Gaussian clusters: 200 points around 0 and 200 around 2 (std 0.5 each)
A = np.random.normal(0, .5, 200)
B = np.random.normal(2, .5, 200)
# Stack into a column of shape (400, 1), the layout fit() expects
X = np.concatenate([A, B]).reshape(-1, 1)

dpgmm = sklearn.mixture.DPGMM(n_iter=1000, verbose=True,
                              n_components=2, alpha=1, covariance_type='full')
dpgmm.fit(X)
print(dpgmm.means_)
print(dpgmm.weights_)

To plot it I'm using:

def pdf(x, mean, var, weight):
    # Weighted 1-D Gaussian density
    p = np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)
    return p * weight

def plot_gmm(model, data):
    colors = np.array(['r', 'g', 'b', 'm', 'c'])
    assignments = model.predict(data)
    # Only components that at least one point was assigned to
    clusters = np.unique(assignments)
    fig, ax = plt.subplots()
    ax.scatter(data[:, 0], np.zeros(data.size), c=colors[assignments], alpha=0.1)
    means = model.means_[clusters]
    variances = np.array([x.ravel() for x in np.array(model._get_covars())[clusters]])
    weights = model.weights_[clusters]
    # Evaluate the curves over the full data range so every component is visible
    x = np.linspace(data.min(), data.max(), 100)
    for i in range(len(clusters)):
        y = pdf(x, means[i], variances[i], weights[i])
        # Index colors by component id so each curve matches its scatter points
        ax.plot(x, y, color=colors[clusters[i]], linewidth=2)
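
A quick way to sanity-check the pdf helper on its own, without any fitted model: each weighted component should integrate to roughly its weight, so a mixture whose weights sum to 1 should integrate to about 1 over a range wide enough to cover the data. The sketch below assumes illustrative component parameters chosen to match the synthetic data (means 0 and 2, variance 0.25, equal weights); they are not taken from a fit, and trapezoid_area is a hypothetical helper written out by hand.

```python
import numpy as np

def pdf(x, mean, var, weight):
    """Weighted 1-D Gaussian density, same formula as the helper above."""
    p = np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)
    return p * weight

def trapezoid_area(y, x):
    """Trapezoidal rule, written out to avoid relying on a specific NumPy version."""
    return float(np.sum((y[1:] + y[:-1]) / 2 * np.diff(x)))

# Illustrative components (assumed, not fitted): (mean, variance, weight)
components = [(0.0, 0.25, 0.5), (2.0, 0.25, 0.5)]

# Grid wide enough to cover both components well into their tails
x = np.linspace(-4.0, 6.0, 4001)
areas = [trapezoid_area(pdf(x, m, v, w), x) for m, v, w in components]
total_area = sum(areas)
print(areas, total_area)
```

If the areas come out far from the weights, the problem is in the plotting helper or the grid rather than in the model.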

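With DPGMM, I get a single component across a range of values of n_components and alpha (I can't post images because this is my first question on SO). On the other hand, if I train a vanilla GMM with 2 components, I get something much more in line with what I expect.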

gmm = sklearn.mixture.GMM(n_iter=1000, n_components=2, covariance_type='full')
gmm.fit(X)
plot_gmm(gmm, X)
plt.show()

Is there something I'm missing about DPGMM?
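
For readers on a recent scikit-learn: DPGMM and GMM were deprecated in 0.18 and removed in 0.20, and the documentation points to BayesianGaussianMixture with weight_concentration_prior_type='dirichlet_process' as the replacement for DPGMM. A rough equivalent of the comparison above might look like the sketch below. The seed, the choice of n_components=5 as an upper bound, and the 0.05 weight threshold for calling a component "active" are assumptions made for illustration, and the new estimator is not guaranteed to reproduce what the old DPGMM did.

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture

# Same synthetic data as in the question, with a fixed seed (an assumption)
rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(0, 0.5, 200),
                    rng.normal(2, 0.5, 200)]).reshape(-1, 1)

# Plain EM mixture with exactly two components (successor of the old GMM)
gmm = GaussianMixture(n_components=2, covariance_type="full",
                      max_iter=1000, random_state=0).fit(X)

# Dirichlet-process mixture: n_components is an upper bound, and
# weight_concentration_prior plays the role that alpha played in DPGMM
bgm = BayesianGaussianMixture(
    n_components=5,
    covariance_type="full",
    weight_concentration_prior_type="dirichlet_process",
    weight_concentration_prior=1.0,
    max_iter=1000,
    random_state=0,
).fit(X)

gmm_means = np.sort(gmm.means_.ravel())
# Components with non-negligible weight (0.05 is an arbitrary cut-off)
active = bgm.weights_ > 0.05

print("GMM means:", gmm_means)
print("DP mixture weights:", np.round(bgm.weights_, 3))
print("DP mixture means of active components:", bgm.means_[active].ravel())
```

Comparing the printed weights for a few values of weight_concentration_prior is one way to see how strongly the prior pushes the model toward fewer components.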

0 answers:

No answers