显示错误分类的实例

时间:2016-01-24 13:37:28

标签: python scikit-learn

我正在使用Scikit-learn构建SVM分类器......并且在运行分类器时...我想通过检查错误分类的实例并试图找出错误分类背后的原因来提高分类器的准确性。 .. 那么有没有办法显示错误分类的实例?

2 个答案:

答案 0 :(得分:5)

  

有没有办法显示错误分类的实例?

是的,你需要在这里和那里做一些索引。下面是一个示例,但技术细节将取决于分类器的输入和输出。

简单的情况是输出是单个值,因此您可以轻松地比较实例是否已正确分类。例如,让我们收集一些数据并训练二进制分类器:

>>> from sklearn import cross_validation, datasets, svm
>>> X, y = datasets.make_classification()
>>> X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y)
>>> clf = svm.LinearSVC().fit(X_train, y_train)
>>> y_pred = clf.predict(X_test)

您可以直接比较y_testy_pred,因为输出是单个值。如果您正在训练多类模型,那么您将无法进行简单的比较,而是应该逐个比较。

>>> misclassified_samples = X_test[y_test != y_pred]

如果需要,您也可以将布尔掩码转换为索引。

>>> import numpy as np
>>> np.flatnonzero(y_test != y_pred)
array([ 0, 20, 22])

答案 1 :(得分:1)

我假设您使用线性SVM。如果没有,这是非常相似的程序。

from sklearn.svm import LinearSVC
X_train=your_train_data
y_train=your_train_lables
X_test=your_test_data #should be around 30% of you your data
y_test=your_test_labels
svm = LinearSVC()
svm.fit(X_train, y_train)
for item, label in zip(X_test, y_test):
    result = svm.predict([item])
    if result != label:
        print "predicted label %s, but true label is %s" % (result, label)

这将打印出分类器对测试数据所做的每一个错误。