我正在训练MLP并使用sklearn的0.18dev版本。我不知道我的代码有什么问题。你能帮忙吗?
# TODO: Import 'GridSearchCV' and 'make_scorer'
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import make_scorer
# TODO: Create the parameters list you wish to tune
parameters = {'max_iter' : [100,200]}
# TODO: Initialize the classifier
clf = clf_B
# TODO: Make an f1 scoring function using 'make_scorer'
f1_scorer = make_scorer(f1_score, pos_label = 'Yes')
# TODO: Perform grid search on the classifier using the f1_scorer as the scoring method
grid_obj = GridSearchCV(clf,parameters,scoring = f1_scorer)
# TODO: Fit the grid search object to the training data and find the optimal parameters
grid_obj = grid_obj.fit(X_train, y_train)
# Get the estimator
clf = grid_obj.best_estimator_
# Report the final F1 score for training and testing after parameter tuning
print "Tuned model has a training F1 score of {:.4f}.".format(predict_labels(clf, X_train, y_train))
print "Tuned model has a testing F1 score of {:.4f}.".format(predict_labels(clf, X_test, y_test))
错误消息
---
IndexError Traceback (most recent call last)
<ipython-input-216-4a3fb1d65cb7> in <module>()
24
25 # TODO: Fit the grid search object to the training data and find the optimal parameters
---> 26 grid_obj = grid_obj.fit(X_train, y_train)
27
28 # Get the estimator
/home/indy/anaconda2/lib/python2.7/site-packages/sklearn/grid_search.pyc in fit(self, X, y)
810
811 """
--> 812 return self._fit(X, y, ParameterGrid(self.param_grid))
813
814
/home/indy/anaconda2/lib/python2.7/site-packages/sklearn/grid_search.pyc in _fit(self, X, y, parameter_iterable)
537 'of samples (%i) than data (X: %i samples)'
538 % (len(y), n_samples))
--> 539 cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
540
541 if self.verbose > 0:
/home/indy/anaconda2/lib/python2.7/site-packages/sklearn/cross_validation.pyc in check_cv(cv, X, y, classifier)
1726 if classifier:
1727 if type_of_target(y) in ['binary', 'multiclass']:
-> 1728 cv = StratifiedKFold(y, cv)
1729 else:
1730 cv = KFold(_num_samples(y), cv)
/home/indy/anaconda2/lib/python2.7/site-packages/sklearn/cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state)
546 for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
547 for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 548 label_test_folds = test_folds[y == label]
549 # the test split can be too big because we used
550 # KFold(max(c, self.n_folds), self.n_folds) instead of
IndexError: too many indices for array
MLPClassifier,这就是我输入的样子。
clf_B = MLPClassifier(random_state=4)
print X_train
school_GP school_MS sex_F sex_M age address_R address_U \
171 1.0 0.0 0.0 1.0 16 0.0 1.0
12 1.0 0.0 0.0 1.0 15 0.0 1.0
13 1.0 0.0 0.0 1.0 15 0.0 1.0
151 1.0 0.0 0.0 1.0 16 0.0 1.0
310 1.0 0.0 1.0 0.0 19 0.0 1.0
274 1.0 0.0 1.0 0.0 17 0.0 1.0
371 0.0 1.0 0.0 1.0 18 1.0 0.0
29 1.0 0.0 0.0 1.0 16 0.0 1.0
109 1.0 0.0 1.0 0.0 16 0.0 1.0
327 1.0 0.0 0.0 1.0 17 1.0 0.0
131 1.0 0.0 1.0 0.0 15 0.0 1.0
128 1.0 0.0 0.0 1.0 18 1.0 0.0
174 1.0 0.0 1.0 0.0 16 0.0 1.0
108 1.0 0.0 0.0 1.0 15 1.0 0.0
280 1.0 0.0 0.0 1.0 17 0.0 1.0
163 1.0 0.0 0.0 1.0 17 0.0 1.0
178 1.0 0.0 0.0 1.0 16 1.0 0.0
275 1.0 0.0 1.0 0.0 17 0.0 1.0
35 1.0 0.0 1.0 0.0 15 0.0 1.0
276 1.0 0.0 1.0 0.0 18 1.0 0.0
282 1.0 0.0 1.0 0.0 18 1.0 0.0
99 1.0 0.0 1.0 0.0 16 0.0 1.0
194 1.0 0.0 0.0 1.0 16 0.0 1.0
357 0.0 1.0 1.0 0.0 17 0.0 1.0
10 1.0 0.0 1.0 0.0 15 0.0 1.0
112 1.0 0.0 1.0 0.0 16 0.0 1.0
338 1.0 0.0 1.0 0.0 18 0.0 1.0
292 1.0 0.0 1.0 0.0 18 0.0 1.0
305 1.0 0.0 1.0 0.0 18 0.0 1.0
340 1.0 0.0 1.0 0.0 19 0.0 1.0
.. ... ... ... ... ... ... ...
255 1.0 0.0 0.0 1.0 17 0.0 1.0
58 1.0 0.0 0.0 1.0 15 0.0 1.0
33 1.0 0.0 0.0 1.0 15 0.0 1.0
38 1.0 0.0 1.0 0.0 15 1.0 0.0
359 0.0 1.0 1.0 0.0 18 0.0 1.0
51 1.0 0.0 1.0 0.0 15 0.0 1.0
363 0.0 1.0 1.0 0.0 17 0.0 1.0
260 1.0 0.0 1.0 0.0 18 0.0 1.0
102 1.0 0.0 0.0 1.0 15 0.0 1.0
195 1.0 0.0 1.0 0.0 17 0.0 1.0
167 1.0 0.0 1.0 0.0 16 0.0 1.0
293 1.0 0.0 1.0 0.0 17 1.0 0.0
116 1.0 0.0 0.0 1.0 15 0.0 1.0
124 1.0 0.0 1.0 0.0 16 0.0 1.0
218 1.0 0.0 1.0 0.0 17 0.0 1.0
287 1.0 0.0 1.0 0.0 17 0.0 1.0
319 1.0 0.0 1.0 0.0 18 0.0 1.0
47 1.0 0.0 0.0 1.0 16 0.0 1.0
213 1.0 0.0 0.0 1.0 18 0.0 1.0
389 0.0 1.0 1.0 0.0 18 0.0 1.0
95 1.0 0.0 1.0 0.0 15 1.0 0.0
162 1.0 0.0 0.0 1.0 16 0.0 1.0
263 1.0 0.0 1.0 0.0 17 0.0 1.0
360 0.0 1.0 1.0 0.0 18 1.0 0.0
75 1.0 0.0 0.0 1.0 15 0.0 1.0
299 1.0 0.0 0.0 1.0 18 0.0 1.0
22 1.0 0.0 0.0 1.0 16 0.0 1.0
72 1.0 0.0 1.0 0.0 15 1.0 0.0
15 1.0 0.0 1.0 0.0 16 0.0 1.0
168 1.0 0.0 1.0 0.0 16 0.0 1.0
famsize_GT3 famsize_LE3 Pstatus_A ... higher internet \
171 1.0 0.0 0.0 ... 1 1
12 0.0 1.0 0.0 ... 1 1
13 1.0 0.0 0.0 ... 1 1
151 0.0 1.0 0.0 ... 1 0
310 0.0 1.0 0.0 ... 1 0
274 1.0 0.0 0.0 ... 1 1
371 0.0 1.0 0.0 ... 0 1
29 1.0 0.0 0.0 ... 1 1
109 0.0 1.0 0.0 ... 1 1
327 1.0 0.0 0.0 ... 1 1
131 1.0 0.0 0.0 ... 1 1
128 1.0 0.0 0.0 ... 1 1
174 0.0 1.0 0.0 ... 1 1
108 1.0 0.0 0.0 ... 1 1
280 0.0 1.0 1.0 ... 1 1
163 1.0 0.0 0.0 ... 0 1
178 1.0 0.0 0.0 ... 1 1
275 0.0 1.0 0.0 ... 1 1
35 1.0 0.0 0.0 ... 1 0
276 1.0 0.0 1.0 ... 0 1
282 0.0 1.0 0.0 ... 1 0
99 1.0 0.0 0.0 ... 1 1
194 1.0 0.0 0.0 ... 1 1
357 0.0 1.0 1.0 ... 1 0
10 1.0 0.0 0.0 ... 1 1
112 1.0 0.0 0.0 ... 1 1
338 0.0 1.0 0.0 ... 1 1
292 0.0 1.0 0.0 ... 1 1
305 1.0 0.0 0.0 ... 1 1
340 1.0 0.0 0.0 ... 1 1
.. ... ... ... ... ... ...
255 0.0 1.0 0.0 ... 1 1
58 0.0 1.0 0.0 ... 1 1
33 0.0 1.0 0.0 ... 1 1
38 1.0 0.0 0.0 ... 1 1
359 0.0 1.0 0.0 ... 1 1
51 0.0 1.0 0.0 ... 1 1
363 0.0 1.0 0.0 ... 1 1
260 1.0 0.0 0.0 ... 1 1
102 1.0 0.0 0.0 ... 1 1
195 0.0 1.0 0.0 ... 1 1
167 1.0 0.0 0.0 ... 1 1
293 0.0 1.0 0.0 ... 1 0
116 1.0 0.0 0.0 ... 1 0
124 1.0 0.0 0.0 ... 1 1
218 1.0 0.0 0.0 ... 1 0
287 1.0 0.0 0.0 ... 1 1
319 1.0 0.0 0.0 ... 1 1
47 1.0 0.0 0.0 ... 1 1
213 1.0 0.0 0.0 ... 1 1
389 1.0 0.0 0.0 ... 1 0
95 1.0 0.0 0.0 ... 1 1
162 0.0 1.0 0.0 ... 1 0
263 1.0 0.0 0.0 ... 1 0
360 0.0 1.0 1.0 ... 1 0
75 1.0 0.0 0.0 ... 1 1
299 0.0 1.0 0.0 ... 1 1
22 0.0 1.0 0.0 ... 1 1
72 1.0 0.0 0.0 ... 1 1
15 1.0 0.0 0.0 ... 1 1
168 1.0 0.0 0.0 ... 1 1
romantic famrel freetime goout Dalc Walc health absences
171 1 4 3 2 1 1 3 2
12 0 4 3 3 1 3 5 2
13 0 5 4 3 1 2 3 2
151 1 4 4 4 3 5 5 6
310 1 4 2 4 2 2 3 0
274 1 4 3 3 1 1 1 2
371 1 4 3 3 2 3 3 3
29 1 4 4 5 5 5 5 16
109 1 5 4 5 1 1 4 4
327 0 4 4 5 5 5 4 8
131 1 4 3 3 1 2 4 0
128 0 3 3 3 1 2 4 0
174 0 4 4 5 1 1 4 4
108 1 1 3 5 3 5 1 6
280 1 4 5 4 2 4 5 30
163 0 5 3 3 1 4 2 2
178 1 4 3 3 3 4 3 10
275 1 4 4 4 2 3 5 6
35 0 3 5 1 1 1 5 0
276 1 4 1 1 1 1 5 75
282 0 5 2 2 1 1 3 1
99 0 5 3 5 1 1 3 0
194 0 5 3 3 1 1 3 0
357 1 1 2 3 1 2 5 2
10 0 3 3 3 1 2 2 0
112 0 3 1 2 1 1 5 6
338 0 5 3 3 1 1 1 7
292 1 5 4 3 1 1 5 12
305 0 4 4 3 1 1 3 8
340 1 4 3 4 1 3 3 4
.. ... ... ... ... ... ... ... ...
255 0 4 4 4 1 2 5 2
58 0 4 3 2 1 1 5 2
33 0 5 3 2 1 1 2 0
38 0 4 3 2 1 1 5 2
359 0 5 3 2 1 1 4 0
51 0 4 3 3 1 1 5 2
363 1 2 3 4 1 1 1 0
260 1 3 1 2 1 3 2 21
102 0 5 3 3 1 1 5 4
195 1 4 3 2 1 1 5 0
167 1 4 2 3 1 1 3 0
293 0 3 1 2 1 1 3 6
116 0 4 4 3 1 1 2 2
124 1 5 4 4 1 1 5 0
218 0 3 3 3 1 4 3 3
287 0 4 3 3 1 1 3 6
319 0 4 4 4 3 3 5 2
47 0 4 2 2 1 1 2 4
213 0 4 4 4 2 4 5 15
389 0 1 1 1 1 1 5 0
95 0 3 1 2 1 1 1 2
162 0 4 4 4 2 4 5 0
263 0 3 2 3 1 1 4 4
360 1 4 3 4 1 4 5 0
75 0 4 3 3 2 3 5 6
299 1 1 4 2 2 2 1 5
22 0 4 5 1 1 3 5 2
72 1 3 3 4 2 4 5 2
15 0 4 4 4 1 2 2 4
168 0 5 1 5 1 1 4 0
[300 rows x 48 columns]
这就是我的输出的样子
print y_train
passed
171 yes
12 yes
13 yes
151 yes
310 no
274 yes
371 yes
29 yes
109 yes
327 yes
131 no
128 no
174 no
108 yes
280 no
163 yes
178 no
275 yes
35 no
276 no
282 yes
99 no
194 yes
357 yes
10 no
112 yes
338 yes
292 yes
305 yes
340 yes
.. ...
255 no
58 no
33 yes
38 yes
359 yes
51 yes
363 yes
260 yes
102 yes
195 yes
167 yes
293 yes
116 yes
124 no
218 no
287 yes
319 yes
47 yes
213 no
389 no
95 yes
162 no
263 no
360 yes
75 yes
299 yes
22 yes
72 no
15 yes
168 no
[300 rows x 1 columns]