scikit学习交叉验证classification_report

时间:2016-05-09 19:58:10

标签: python scikit-learn

我希望每个类标签都有度量标准,并且scikit中的交叉验证需要汇总混淆矩阵。

我写了一个方法,对scikit进行交叉验证学习,总结混淆矩阵,并存储所有预测标签。然后,它调用scikit学习方法来打印指标。

以下代码应与最近的scikit learn安装一起运行,您可以使用任何数据集进行测试。

在进行StratifiedKFold交叉验证时,是否低于收集汇总cmclassification_report的正确方法?

from sklearn import metrics
from sklearn.cross_validation import StratifiedKFold
import numpy as np

def customCrossValidation(self, X, y, classifier, n_folds=10, shuffle=True, random_state=0):
    ''' Perform a cross validation and print out the metrics '''
    skf = StratifiedKFold(y, n_folds=n_folds, shuffle=shuffle, random_state=random_state)
    cm = None
    y_predicted_overall = None
    y_test_overall = None
    for train_index, test_index in skf:
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        classifier.fit(X_train, y_train)
        y_predicted = classifier.predict(X_test)
        # collect the y_predicted per fold
        if y_predicted_overall is None:
            y_predicted_overall = y_predicted
            y_test_overall = y_test
        else: 
            y_predicted_overall = np.concatenate([y_predicted_overall, y_predicted])
            y_test_overall = np.concatenate([y_test_overall, y_test])
        cv_cm = metrics.confusion_matrix(y_test, y_predicted)
        # sum the cv per fold
        if cm is None:
            cm = cv_cm
        else:
            cm += cv_cm
    print (metrics.classification_report(y_test_overall, y_predicted_overall, digits=3))
    print (cm)

0 个答案:

没有答案