KNN模型的准确性得分(IRIS数据)

时间:2019-07-05 01:08:14

标签: python algorithm machine-learning scikit-learn knn

此基本KNN模型在IRIS数据上提高或稳定准确性得分请勿更改有显着差异)的关键因素可能是什么?

尝试

from sklearn import neighbors, datasets, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

iris = datasets.load_iris() 
X, y = iris.data[:, :], iris.target

Xtrain, Xtest, y_train, y_test = train_test_split(X, y)
scaler = preprocessing.StandardScaler().fit(Xtrain)
Xtrain = scaler.transform(Xtrain)
Xtest = scaler.transform(Xtest)

knn = neighbors.KNeighborsClassifier(n_neighbors=4)
knn.fit(Xtrain, y_train)
y_pred = knn.predict(Xtest)

print(accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

样本准确性得分

0.9736842105263158
0.9473684210526315
1.0
0.9210526315789473

分类报告

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        12
           1       0.79      1.00      0.88        11
           2       1.00      0.80      0.89        15

    accuracy                           0.92        38
   macro avg       0.93      0.93      0.92        38
weighted avg       0.94      0.92      0.92        38

样本混淆矩阵

[[12  0  0]
 [ 0 11  0]
 [ 0  3 12]]

2 个答案:

答案 0 :(得分:2)

在虹膜数据集中只有3个类别可用,即鸢尾花-Setosa,鸢尾花-Virginica和鸢尾花-Versicolor。

使用此代码。这使我97.78%的准确性

from sklearn import neighbors, datasets, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

iris = datasets.load_iris() 
X, y = iris.data[:, :], iris.target
Xtrain, Xtest, y_train, y_test = train_test_split(X, y, stratify = y, random_state = 0, train_size = 0.7)

scaler = preprocessing.StandardScaler().fit(Xtrain)
Xtrain = scaler.transform(Xtrain)
Xtest = scaler.transform(Xtest)

knn = neighbors.KNeighborsClassifier(n_neighbors=3)
knn.fit(Xtrain, y_train)
y_pred = knn.predict(Xtest)

print(accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

答案 1 :(得分:1)

我建议为k-NN调整k的值。由于鸢尾花是一个很小的数据集,并且平衡良好,因此我将执行以下操作:

For every value of `k` in range [2 to 10] (say)
  Perform a n-times k-folds crossvalidation (say n=20 and k=4)
    Store the Accuracy values (or any other metric)

根据平均值和方差绘制分数,并选择具有最佳k值的k值。交叉验证的主要目标是估计测试误差,然后根据您选择最终模型。会有一些差异,但应小于0.03或类似的值。这取决于数据集和折叠次数。一个好的过程是,为k的每个值制作一个所有20x4精度值的箱线图。选择下分位数与上分位数相交的k值,或者简单地说,就是精度(或其他度量值)的变化不大。

基于此选择k的值后,目标是使用该值来使用整个训练数据集构建最终模型。接下来,可以用来预测新数据。

另一方面,对于较大的数据集。制作一个单独的测试分区(如您在此处所做的那样),然后仅在训练集上调整k的值(使用交叉验证,而无需考虑测试集)。选择适当的k训练算法后,仅使用训练集进行训练。接下来,使用测试集报告最终值。永远不要根据测试集做出任何决定。

另一种方法是训练,验证和测试分区。使用训练集进行训练,并使用k的不同值来训练模型,然后使用验证分区进行预测并列出分数。根据此验证分区选择最佳分数。接下来,使用训练或训练+验证集来训练最终模型,该模型使用基于验证集选择的k的值。最后,取出测试集并报告最终分数。同样,切勿在其他任何地方使用测试集。

这些是适用于任何机器学习或统计学习方法的常规方法。

执行分区(训练,测试或交叉验证)时要注意的重要事项,请使用stratified sampling,以便在每个分区中类比率保持不变。

详细了解crossvalidation。在scikitlearn中很容易做到。如果使用R,则可以使用caret

要记住的主要事情是目标是训练一个泛化新数据或在新数据上表现良好的功能,而不仅仅是在现有数据上表现不佳。