绘制clusterplot时发生Python valueerror(形状未对齐)

时间:2019-11-30 10:00:00

标签: python list

我正在尝试绘制高斯混合模型的结果,但是由于数组的形状不同,显示错误。这是代码:

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
from sklearn import model_selection

K = 3
cov_type = 'full' # e.g. 'full' or 'diag'

# define the initialization procedure (initial value of means)
initialization_method = 'random'#  'random' or 'kmeans'

reps = 1

# Fit Gaussian mixture model
gmm = GaussianMixture(n_components=K, covariance_type=cov_type, n_init=reps, 
                  tol=1e-6, reg_covar=1e-6, init_params=initialization_method).fit(X)
cls = gmm.predict(X)    
# extract cluster labels
cds = gmm.means_        
# extract cluster centroids (means of gaussians)
covs = gmm.covariances_
# extract cluster shapes (covariances of gaussians)
if cov_type.lower() == 'diag':
    new_covs = np.zeros([K,M,M])    

    count = 0    
    for elem in covs:
        temp_m = np.zeros([M,M])
        new_covs[count] = np.diag(elem)
        count += 1

    covs = new_covs

# Plot results:
figure(figsize=(14,9))
clusterplot(X, clusterid=cls, centroids=cds, y=y, covars=covs)
show()

我收到以下错误消息

bp = np.dot(v, np.dot(d, ap)) + np.tile(mean, (1, ap.shape[1]))

ValueError: shapes (8,8) and (2,100) not aligned: 8 (dim 1) != 2 (dim 0)

错误似乎是clusterplot的倒数第二行,但是我不知道为什么。任何帮助将不胜感激,谢谢!

0 个答案:

没有答案