我正在尝试为K
找到最佳的KNeighborsClassifier
值。
这是我的iris
数据集的代码:
k_loop = np.arange(1,30)
k_scores = []
for k in k_loop:
knn = KNeighborsClassifier(n_neighbors=k)
cross_val = cross_val_score(knn, X, y, cv=10 , scoring='accuracy')
k_scores.append(cross_val.mean())
我在每个循环中均取了cross_val_score的平均值并将其绘制出来。
plt.style.use('fivethirtyeight')
plt.plot(k_loop, k_scores)
plt.show()
这是结果。
当k
在14
至20
之间时,可以看到精度更高。
1)如何选择k的最佳值。
2)是否有其他方法可以计算和找到K
的最佳价值?
3)任何其他改进建议也将受到赞赏。我是ML
的新手
答案 0 :(得分:3)
让我们首先定义什么是K
?
K
是算法参考以决定给定数据点属于哪个类的 投票者 的数量< / em>。
换句话说,它使用K
来划分每个类的边界。这些界限将每个类别彼此隔离。
因此,边界随着K
值的增加而变得更加平滑。
从逻辑上讲,如果我们将K
增大为 ,它将最终成为任何类别的所有点,具体取决于 占绝大多数 !但是,这将导致所谓的高偏差(即拟合不足)。
相反,如果我们使K
仅等于 1 ,那么对于 训练样本,误差将始终为零。 / em> 。这是因为最接近任何训练数据点的点本身就是它。尽管如此,我们最终会 过拟合 (即高方差)边界,因此对于任何新的和看不见的数据,它无法概括!。>
不幸的是,没有没有经验法则。 K
的选择在某种程度上受最终应用程序和数据集的驱动。
使用GridSearchCV对估计器的指定参数值进行详尽搜索。因此,我们使用它来尝试找到K
的最佳值。
对于我来说,当我想要设置K
的最大阈值时,我没有超过每个类中元素数量的最大类,并且至今为止我还没有失望( 稍后再看示例,看看我在说什么)
示例:
import numpy as np
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
iris = datasets.load_iris()
X, y = iris.data, iris.target
# get the max class with respect to the number of elements
max_class = np.max(np.bincount(y))
# you can add other parameters after doing your homework research
# for example, you can add 'algorithm' : ['auto', 'ball_tree', 'kd_tree', 'brute']
grid_param = {'n_neighbors': range(1, max_class)}
model = KNeighborsClassifier()
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=2)
clf = GridSearchCV(model, grid_param, cv=cv, scoring='accuracy')
clf.fit(X, y)
print("Best Estimator: \n{}\n".format(clf.best_estimator_))
print("Best Parameters: \n{}\n".format(clf.best_params_))
print("Best Score: \n{}\n".format(clf.best_score_))
结果
Best Estimator:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=17, p=2,
weights='uniform')
Best Parameters:
{'n_neighbors': 17}
Best Score:
0.98
关于RepeatedStratifiedKFold
简单来说,是KFold
在n_repeats
次上被重复,为什么?因为它可以降低偏差,并为您提供更好的统计估计。
它也是Stratified
的目标,是确保每个测试折叠中的每个类别均近似均等表示(即每个折叠代表全部 strong>数据层)。
答案 1 :(得分:1)
根据图表,我会说13。
我认为这是一项分类工作。
在这种情况下:请勿将k设置为偶数。
例如如果您有2个A类和B类,并且k设置为4。
新数据(或点)有可能
在2类A和2类B之间。
因此,您将有2次投票将新数据点归类为A
和2票被归类为B。
将k设置为奇数可避免这种情况。