我正在使用gridsearchCV找到BIRCH的最佳参数,我的代码是:
RAND_STATE=50 # for reproducibility and consistency
folds=3
k_fold = KFold(n_splits=folds, shuffle=True, random_state=RAND_STATE)
hyperparams = { "branching_factor": [50,100,200,300,400,500,600,700,800,900],
"n_clusters": [5,7,9,11,13,17,21],
"threshold": [0.2,0.3,0.4,0.5,0.6,0.7]}
birch = Birch()
def sil_score(ndata):
labels = ensemble.predict(ndata)
score = silhouette_score(ndata, labels)
return score
sil_scorer = make_scorer(sil_score)
ensemble = GridSearchCV(estimator=birch,param_grid=hyperparams,scoring=sil_scorer,cv=k_fold,verbose=10,n_jobs=-1)
ensemble.fit(x)
print ensemble
best_parameters = ensemble.best_params_
print best_parameters
best_score = ensemble.best_score_
print best_score
然而输出给我一个错误:
我很困惑为什么当我已经说明了在sil_score函数中得分所需的参数时,得分值正在寻找4个参数。
答案 0 :(得分:2)
您的评分功能不正确。语法应为sil_score(y_true,y_pred)
,其中y_true是基本事实标签,y_pred
是预测标签。此外,您无需使用评分函数中的整体对象单独预测标签。同样在你的情况下,直接使用silhouette_score
作为评分函数更有意义,因为你正在调用你的整体来预测评分函数内部的标签,这根本不是必需的。只需传递silhouette_score
作为评分函数,GridSearchCV将负责预测自己的评分。
Here is an example如果你想知道它是如何运作的。