我正在使用scikit-learn进行GMM训练,并试图通过循环整数列表来改变混合成分的数量。但是当我打印出我生成的模型时,我只会得到3个混合组件,或者我列出的最后一项。
这是我的代码:
from sklearn.mixture import GMM
class_names = ['name1','name2','name3']
covs = ['spherical', 'diagonal', 'tied', 'full']
num_comp = [1,2,3]
models = {}
for c in class_names:
models[c] = dict((covar_type,GMM(n_components=num,
covariance_type=covar_type, init_params='wmc',n_init=1, n_iter=10)) for covar_type in covs for num in num_comp)
print models
有人可以帮忙吗? 非常感谢提前!
答案 0 :(得分:4)
这是因为在表达式中:
dict((covar_type,GMM(n_components=num,
covariance_type=covar_type, init_params='wmc',n_init=1, n_iter=10)) for covar_type in covs for num in num_comp)
您在所有迭代中使用相同的covar_type
作为键,从而重写相同的元素。
如果我们以更易读的方式编写代码,就会发生这种情况:
data = dict()
for covar_type in covs:
for num in num_comp:
# covar_type is the same for all iterations of this loop
# hence only the last one "survives"
data[covar_type] = GMM(...)
如果要保留所有值,则应使用值列表而不是单个值或更改密钥。
对于值列表:
data = dict()
for covar_type in covs:
data[covar_type] = values = []
for num in num_comp:
values.append(GMM(...))
对于不同的键:
data = dict()
for covar_type in covs:
for num in num_comp:
data[(covar_type, num)] = GMM(...)