我使用GridSearchCV来调整SVM分类器,然后绘制学习曲线。但是,除非我在绘制学习曲线之前设置了一个新的分类器,否则我会遇到一个IndexError,我不确定为什么会这样。
我的简历/分类器设置如下:
# Set up classifier
clf_untuned = OneVsRestClassifier(SVC(kernel='rbf', random_state=0, max_iter=1000))
cv = cross_validation.ShuffleSplit(data_image.shape[1], n_iter=10,
test_size=0.1, random_state=0)
# Use cross validation / grid search to find optimal hyperparameters
if TRAINING_CROSS_VALIDATION == 1:
params = {
...
}
clf_tuned = GridSearchCV(clf_untuned, cv=cv, param_grid=params)
clf_tuned.fit(x_train, y_train)
print('Best parameters: %s' % clf_tuned.best_params_)
else:
clf_tuned = OneVsRestClassifier(SVC(kernel='rbf',
C=100, gamma=0.00001, random_state=0, verbose=0))
clf_tuned.fit(x_train, y_train)
然后我继续绘制学习曲线,其中plot_learning_curve复制了sklearn示例(http://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html)。如果我使用以下代码,那么我会在“学习_曲线”中收到以下错误: plot_learning_curve中的一行:
# Plot learning curve for best params -- yields IndexError
plot_learning_curve(clf_tuned, title, x_train, y_train, ylim=(0.6, 1.05), cv=cv)
IndexError:索引663超出了70的范围
然而,如果我开始一个新的分类,那么一切正常:
# Plot learning curve for best params -- functions correctly
estimator = OneVsRestClassifier(SVC(kernel='rbf',
C=100, gamma=0.00001, random_state=0, verbose=0))
plot_learning_curve(estimator, title, x_train, y_train, ylim=(0.6, 1.05), cv=cv)
这是为什么?非常感谢,并欢迎对我可疑的实施提出其他意见。
答案 0 :(得分:1)
通过将通过网格搜索获得的最佳估算值传递为clf_tuned.best_estimator _
来解决问题