sklearn中learning_curve函数中estimator参数的值应该是什么?

时间:2018-07-03 10:29:40

标签: python machine-learning scikit-learn data-science knn

我正在尝试画一条学习曲线,而我想使用的算法是knn算法。为此,估算器的值应该是多少。它可能的值或选项不在文档中(并且我不确定是否应该存在)。

这是我的代码-

features = ['age','sex','cp','trestbps','chol','fbs','restecg','thalach','exang','oldpeak','slope','ca','thal']
target = 'num'

train_size, train_scores, validation_scores = learning_curve(estimator = KNN(), x=dataset[features], y=dataset[target], train_size=train_sizes, cv=5, scoring='confusion_matrix')
错误是-未定义KNN()(这很明显是为什么)。但是我的问题是,如果我想使用knn算法,它应该有什么价值。

1 个答案:

答案 0 :(得分:2)

来自learning curve docs

  

estimator :实现“ fit”和“ predict”方法的对象类型

因此,如果您处于回归设置中,则应使用

from sklearn.neighbors import KNeighborsRegressor
# define the no. of nearest neighbors k
train_size, train_scores, validation_scores = learning_curve(estimator = KNeighborsRegressor(n_neighbors=k), [...])

如果处于分类设置中,则应使用

from sklearn.neighbors import KNeighborsClassifier
# define the no. of nearest neighbors k
train_size, train_scores, validation_scores = learning_curve(estimator = KNeighborsClassifier(n_neighbors=k), [...])

在两种情况下当然都应该定义最近的邻居k的数量。

通常的想法是,在estimator参数中,您可以使用实现s fitpredist方法的任何scikit-learn可用算法,如文档中明确提到的(上面提供的链接) )。