使用OnevsRest分类器时,是否可以知道每个分类器的准确性?

时间:2017-12-12 22:27:18

标签: python scikit-learn

我使用的是OnevsRest分类器。 我有一个包含21个类的数据集。我想知道每个分类器的准确性。

例如:

class1 vs(class2 + classx ... + class21)

的准确性

class2 vs(class3 + classx ... + class21)

的准确性

class21 vs(class1 + classx ... + class20)

的准确性

我怎么知道?

# Learn to predict each class against the other
classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True, random_state=random_state))
y_score = classifier.fit(X_train, y_train).score(X_test, y_test)
print(y_score)

1 个答案:

答案 0 :(得分:0)

我认为这不是开箱即用的,你需要自己动手。

这是一些示例代码,我称之为原型,因为它没有经过严格测试!请记住,很难比较单类精度和元准确度(基于概率估计;在通过Platt缩放获得的SVM案例中)。

import numpy as np
from sklearn import datasets
from sklearn import svm
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import train_test_split

# Data
iris = datasets.load_iris()
iris_X = iris.data
iris_y = iris.target
X_train, X_test, y_train, y_test = train_test_split(
    iris_X, iris_y, test_size=0.5, random_state=0)

# Train classifier
classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
    random_state=0))
y_score = classifier.fit(X_train, y_train).score(X_test, y_test)
print(y_score)

# Get all accuracies
classes = np.unique(y_train)

def get_acc_single(clf, X_test, y_test, class_):
    pos = np.where(y_test == class_)[0]
    neg = np.where(y_test != class_)[0]
    y_trans = np.empty(X_test.shape[0], dtype=bool)
    y_trans[pos] = True
    y_trans[neg] = False
    return clf.score(X_test, y_trans)                    # assumption: acc = default-scorer

for class_index, est in enumerate(classifier.estimators_):
    class_ = classes[class_index]
    print('class ' + str(class_))
    print(get_acc_single(est, X_test, y_test, class_))

输出:

0.8133333333333334
class 0
1.0
class 1
0.6666666666666666
class 2
0.9733333333333334