GridSearchCV:训练集的混淆矩阵

时间:2020-07-03 14:49:07

标签: python confusion-matrix gridsearchcv

我在这个话题上很新,希望任何人都能帮助我。
我将这个GridSearchCV与打印结果一起使用,我想使用训练集来创建一个混淆矩阵。

我该怎么做?

param_grid = [
  {'C': [0.01, 0.1, 0.5, 1., 10., 50., 100., 1000.], 'kernel': ['linear']},
  {'C': [0.01, 0.1, 0.5, 1., 10., 50., 100., 1000.], 'gamma': [0.0001, 0.001, 0.01, 0.1, 0.5, 1., 10., 50., 100.], 'kernel': ['rbf']},
 ]

def tn(y_true, y_pred): 
    return confusion_matrix(y_true, y_pred)[0, 0]
def fp(y_true, y_pred): 
    return confusion_matrix(y_true, y_pred)[0, 1]
def fn(y_true, y_pred): 
    return confusion_matrix(y_true, y_pred)[1, 0]
def tp(y_true, y_pred): 
    return confusion_matrix(y_true, y_pred)[1, 1]
scorer = {'tp': make_scorer(tp), 'tn': make_scorer(tn),
           'fp': make_scorer(fp), 'fn': make_scorer(fn),
          'balanced_accuracy': make_scorer(balanced_accuracy_score)
          }

svm_clf_gs = GridSearchCV(svm.SVC(class_weight="balanced"), param_grid, cv=15, scoring=scorer, refit='balanced_accuracy', n_jobs=6)

svm_clf_gs.fit(descriptors, activities)

model = svm_clf_gs.best_estimator_
cv_results = svm_clf_gs.cv_results_

print(cv_results)
print("Best Params:", svm_clf_gs.best_params_)

0 个答案:

没有答案