我想使用GMM对经典虹膜数据集进行聚类。我从以下位置获取数据集:
https://gist.github.com/netj/8836201
到目前为止,我的程序如下:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture as mix
from sklearn.cross_validation import StratifiedKFold
def main():
data=pd.read_csv("iris.csv",header=None)
data=data.iloc[1:]
data[4]=data[4].astype("category")
data[4]=data[4].cat.codes
target=np.array(data.pop(4))
X=np.array(data).astype(float)
kf=StratifiedKFold(target,n_folds=10,shuffle=True,random_state=1234)
train_ind,test_ind=next(iter(kf))
X_train=X[train_ind]
y_train=target[train_ind]
gmm_calc(X_train,"full",y_train)
def gmm_calc(X_train,cov,y_train):
print X_train
print y_train
n_classes = len(np.unique(y_train))
model=mix(n_components=n_classes,covariance_type="full")
model.means_ = np.array([X_train[y_train == i].mean(axis=0) for i in
xrange(n_classes)])
model.fit(X_train)
y_predict=model.predict(X_train)
print cov," ",y_train
print cov," ",y_predict
print (np.mean(y_predict==y_train))*100
我遇到的问题是,当我尝试获取重合次数y_predict = y_train时,因为每次运行程序时,我都会得到不同的结果。例如:
首次运行:
full [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
full [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 2 2 2 2 2
2 2 2 0 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
0.0
第二次运行:
full [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
full [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0
0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
33.33333333333333
第三次运行:
full [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
full [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1
1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
98.51851851851852
因此,您可以看到每次运行的结果都不同。我在Internet上找到了一些代码:
https://scikit-learn.org/0.16/auto_examples/mixture/plot_gmm_classifier.html
但是,他们获得了完整的协方差,对于火车来说,其准确度约为82%。在这种情况下我怎么了?
谢谢
更新:我发现在互联网示例中,它使用GMM代替了新的GaussianMixture。我还发现,在示例中,GMM参数通过以下方式进行了监督: classifier.means_ = np.array([X_train [y_train == i] .mean(axis = 0) 为我在xrange(n_classes)])
我已将修改后的代码放在上面,但是每次运行它时,它仍然会更改结果,但是对于库GMM来说,它不会发生。
答案 0 :(得分:2)
1)GMM分类器使用Expectation–maximization algorithm来拟合高斯模型的混合:高斯分量随机地以数据点为中心,然后算法移动它们直到收敛到局部最优。由于随机初始化,每次运行的结果可能不同。因此,您还必须使用random_state
的{{1}}参数(或尝试设置更多的初始化次数GMM
,并期待更多类似的结果。)
2)发生准确性问题是因为n_init
(与GMM
一样)适合kmeans
高斯并报告了每个点所属的高斯分量“数字”;该数字在每次运行中都不同。您可以在预测中看到群集是相同的,但是它们的标签被交换了:(1,2,0)->(1,0,2)->(0,1,2),最后一个组合与适当的课程,因此您获得98%的分数。如果将它们标绘,您会发现在这种情况下高斯人本身往往会保持不变,例如
您可以使用许多clustering metrics来考虑到这一点:
n
https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html中的绘图代码请注意,不同版本的代码不同,如果使用旧版本,则需要替换>>> [round(i,5) for i in (metrics.homogeneity_score(y_predict, y_train),
metrics.completeness_score(y_predict, y_train),
metrics.v_measure_score(y_predict,y_train),
metrics.adjusted_rand_score(y_predict, y_train),
metrics.adjusted_mutual_info_score(y_predict, y_train))]
[0.86443, 0.8575, 0.86095, 0.84893, 0.85506]
函数:
make_ellipses
答案 1 :(得分:0)
您的查询很晚。可能对其他人有益。 正如@hellpanderr 所发布的,在 GMM 中使用“random_state=1”