我有一段代码,它生成高斯分布并从中采样数据:
input = pd.read_csv("..\\data\\input.txt", sep=",", header=None).values
gmm = GMM(n_components=5).fit(input)
sampled = gmm.sample(input.shape[0], random_state=42)
original_label = gmm.predict(input)
generated_label = gmm.predict(sampled)
return sampled
当我检查original_label和generated_label时,每个群集中的样本数量不同。
The number of elements in original_label:
Cluster 1:0
Cluster 2:1761
Cluster 3:2024
Cluster 4:769
Cluster 5:0
The number of elements in generated_label:
Cluster 1:0
Cluster 2:1273
Cluster 3:739
Cluster 4:1140
Cluster 5:1402
我想从gmm中采样原始输入的相同分布。这里,采样和原始数据的集群之间存在很大差异。你能帮我解决一下吗?
答案 0 :(得分:0)
高斯混合模型是一种 soft 聚类方法。每个对象都属于每个集群,只是程度不同。
如果要对软簇密度求和,它们应该更紧密地匹配。 (我建议你验证一下。群集5的巨大差异可能表明sklearn存在问题。)
由于群集重叠,生成满足GMM密度模型以及预测的硬标签的硬聚类通常是不可满足的。这证明了“艰难”。标签不适用于GMM的基本假设。