Grid search

Time: 2018-10-20 17:44:20

Tags: python-3.x scikit-learn knn

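I am implementing KNN in Python, and it was working.

Now I get an error:

No module named 'sklearn.grid_search'

When I change the package to sklearn.model_selection, I get another error:

'GridSearchCV' object has no attribute 'grid_scores_'

Here is my code: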

from sklearn.grid_search import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
# define the parameter values that should be searched
# for python 2, k_range = range(1, 31)
# instantiate model
knn = KNeighborsClassifier(n_jobs=-1)
k_range = list(range(1, 31))
print(k_range)
# create a parameter grid: map the parameter names to the values that should be searched
# simply a python dictionary
# key: parameter name
# value: list of values that should be searched for that parameter
# single key-value pair for param_grid
param_grid = dict(n_neighbors=k_range)
print(param_grid)
# instantiate the grid
grid = GridSearchCV(knn, param_grid, cv=10, scoring='accuracy')
# fit the grid with data
grid.fit(X, y)
# view the complete results (list of named tuples)
grid.grid_scores_
# examine the first tuple
# we will slice the list and select its elements using dot notation and []


print('Parameters')
print(grid.grid_scores_[0].parameters)

# Array of 10 accuracy scores during 10-fold cv using the parameters
print('')
print('CV Validation Score')
print(grid.grid_scores_[0].cv_validation_scores)

# Mean of the 10 scores
print('')
print('Mean Validation Score')
print(grid.grid_scores_[0].mean_validation_score)
# create a list of the mean scores only
# list comprehension to loop through grid.grid_scores
grid_mean_scores = [result.mean_validation_score for result in grid.grid_scores_]
print(grid_mean_scores)
# plot the results
# this is identical to the one we generated above
plt.plot(k_range, grid_mean_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
# examine the best model

# Single best score achieved across all params (k)
print(grid.best_score_)

# Dictionary containing the parameters (k) used to generate that score
print(grid.best_params_)

# Actual model object fit with those best parameters
# Shows default parameters that we did not specify

print(grid.best_estimator_)

1 answer:

Answer 0 (score: -1)

Try the following:

from sklearn.model_selection import GridSearchCV

Reference: https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_digits.html
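The import change above only fixes the first error. The second error comes from `grid_scores_`, which was deprecated alongside the `sklearn.grid_search` module and is no longer available on `GridSearchCV` in `sklearn.model_selection`; the `cv_results_` dictionary carries the same information. Below is a minimal sketch of how the question's code might be adapted. It assumes the iris dataset as a stand-in for the question's `X` and `y` (which are not shown), and the names `first_params`, `first_fold_scores`, and `first_mean` are illustrative only.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Assumption: iris stands in for the X, y used in the question.
X, y = load_iris(return_X_y=True)

k_range = list(range(1, 31))
param_grid = dict(n_neighbors=k_range)

# Same search as in the question: 10-fold CV scored by accuracy.
grid = GridSearchCV(KNeighborsClassifier(), param_grid, cv=10, scoring='accuracy')
grid.fit(X, y)

# cv_results_ is a dict of arrays, one entry per parameter combination.
results = grid.cv_results_

# Stands in for grid.grid_scores_[0].parameters
first_params = results['params'][0]

# Stands in for grid.grid_scores_[0].cv_validation_scores
first_fold_scores = [results['split%d_test_score' % i][0] for i in range(10)]

# Stands in for grid.grid_scores_[0].mean_validation_score
first_mean = results['mean_test_score'][0]

# Stands in for the list comprehension over grid.grid_scores_
grid_mean_scores = list(results['mean_test_score'])

print(first_params)
print(first_fold_scores)
print(first_mean)
print(grid.best_score_)
print(grid.best_params_)
```

With `grid_mean_scores` built this way, the plotting lines from the question (`plt.plot(k_range, grid_mean_scores)` and the axis labels) should work without further changes.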