grid_search error when doing deep learning

Time: 2018-12-31 09:26:36

Tags: python keras deep-learning

I am using Python 3.6 for deep learning and am trying to run a grid search. My code is:

import keras
from keras.layers import Dropout
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV

def build_Class():
    # 136 input features, 50% dropout applied to the inputs
    model = keras.models.Sequential()
    model.add(Dropout(0.5, input_shape=(136,)))
    model.add(keras.layers.Dense(10, init='uniform', activation='sigmoid'))
    model.add(keras.layers.Dense(10, init='uniform', activation='relu'))
    # Single sigmoid output unit
    model.add(keras.layers.Dense(1, init='uniform', activation='sigmoid'))
    # rmsle (a custom loss) and optimizer are not defined in the snippet shown
    model.compile(loss=rmsle, optimizer=optimizer, metrics=['accuracy'])  # Need to check with 10 K
    return model

classifier = KerasClassifier(build_fn=build_Class)
parameter = {'batch_size': [10, 25, 32], 'nb_epoch': [100, 500],
             'optimizer': ['adam', 'rmsprop']}
grid_search = GridSearchCV(estimator=classifier, param_grid=parameter, scoring='accuracy', cv=10)

grid_search = grid_search.fit(x, y)
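For a sense of how much work this search involves: `GridSearchCV` expands `param_grid` into every combination via `ParameterGrid` (the same call visible in the traceback), and with `cv=10` each combination is trained ten times, plus one refit on the full data under the default `refit=True`. The sketch below only counts candidates with scikit-learn and does not touch Keras. Note that in Keras 2 the training-length argument is called `epochs`; `nb_epoch` is the Keras 1 name, kept here only because it is what the question uses.

```python
from sklearn.model_selection import ParameterGrid

# Same grid as in the question
parameter = {'batch_size': [10, 25, 32], 'nb_epoch': [100, 500],
             'optimizer': ['adam', 'rmsprop']}

candidates = list(ParameterGrid(parameter))  # every combination of values
n_folds = 10                                 # cv=10 in the question
total_fits = len(candidates) * n_folds       # excludes the final refit

print(len(candidates))  # 12
print(total_fits)       # 120
print(candidates[0])    # {'batch_size': 10, 'nb_epoch': 100, 'optimizer': 'adam'}
```

With 100 or 500 epochs per fit, 120 fits can take a long time, so it may be worth trying a smaller grid first.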

I get the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-36-5a9be9c02d0d> in <module>()
----> 1 grid_search=grid_search.fit(x,y)

/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
    720                 return results_container[0]
    721 
--> 722             self._run_search(evaluate_candidates)
    723 
    724         results = results_container[0]

/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_search.py in _run_search(self, evaluate_candidates)
   1189     def _run_search(self, evaluate_candidates):
   1190         """Search all candidates in param_grid"""
-> 1191         evaluate_candidates(ParameterGrid(self.param_grid))
   1192 
   1193 

/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_search.py in evaluate_candidates(candidate_params)
    709                                for parameters, (train, test)
    710                                in product(candidate_params,
--> 711                                           cv.split(X, y, groups)))
    712 
    713                 all_candidate_params.extend(candidate_params)
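The last frame shown is the statement in `evaluate_candidates` that both splits the data and dispatches one fit per (candidate, fold) pair, and the exception message itself is cut off, so the failing step cannot be read from this output alone. The split step at least does not need class labels: with an integer `cv` such as `cv=10` and a classifier, scikit-learn's `check_cv` stratifies only when `y` is binary or multiclass and otherwise falls back to plain `KFold`. A small check, using the five `y` values from the question and a made-up label array for contrast:

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

# Five target values printed in the question (continuous)
y_continuous = np.array([0.006755, 0.001001, 0.099324, 0.000250, 0.020624])
# Made-up binary labels, only for comparison
y_labels = np.array([0, 1, 0, 1, 1])

# cv=10 for a classifier: stratified only when y holds class labels
cv_for_continuous = check_cv(10, y_continuous, classifier=True)
cv_for_labels = check_cv(10, y_labels, classifier=True)

print(type(cv_for_continuous).__name__)  # KFold
print(type(cv_for_labels).__name__)      # StratifiedKFold
```

If that holds for the installed version, the error more likely comes from fitting or scoring a candidate than from `cv.split` itself.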

My samples x and y look like this:

x

> 361   0.000000    0.000000    0.007459    0.000000    0.00000 0.000000    0.014979    0.00000 0.000000    0.000000    ... 0.071562    0.00000 0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.000000
> 4042  0.000078    0.000000    0.000000    0.000000    0.00000 0.000000    0.000000    0.00000 0.000000    0.000000    ... 0.016101    0.00000 0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.000000
> 3353  0.129063    0.000000    0.178565    0.000000    0.40000 0.000000    0.000000    0.00000 0.000000    0.202589    ... 0.510517    0.00000 0.178565    0.000000    0.178565    0.000000    0.202589    0.000000    0.000000    0.178565
> 1947  0.000000    0.000000    0.000000    0.000000    0.00000 0.000000    0.000000    0.02981 0.000000    0.000000    ... 0.112531    0.00000 0.000000    0.000000    0.000000    0.000000    0.000000    0.011040    0.000000    0.000000
> 779   0.000000    0.000000    0.006258    0.000000    0.00000 0.000000    0.000000    0.00000 0.000000    0.000000    ... 0.020574    0.00000 0.007388    0.005000    0.014518    0.000000    0.031288    0.000000    0.000000    0.006258
> 3439  0.000000    0.001797    0.000000    0.000000    0.00000 0.000000    0.000000    0.00000 0.000000    0.000000    ... 0.000000    0.00000 0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.003129
> 890   0.000000    0.000000    0.000000    0.000000    0.00000 0.150000    0.008268    0.00000 0.000000    0.000626    ... 0.000000    0.00000 0.006258    0.006667    0.000626    0.000000    0.014080    0.000000    0.000000    0.000751
> 1270  0.000000    0.000000    0.000000    0.000000    0.00000 0.000000    0.000000    0.00000 0.000000    0.000000    ... 0.053671    0.00000 0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.000000    0.000000
> 796   0.000000    0.000000    0.000219    0.000000    0.00000 0.000000    0.000000    0.00000 0.000000    0.000000    ... 0.000000    0.00000 0.000219    0.000000    0.000219    0.000000    0.000000    0.000000    0.000000    0.000219
> 4443  0.007361    0.002239    0.001721    0.000000    0.02940 0.164667    0.004414    0.00334 0.001050    0.003650    ... 0.013217    0.01068 0.004892    0.000000    0.004892    0.000000    0.003650    0.019708    0.000418    0.003176

y

>       target
> 211   0.006755
> 1790  0.001001
> 3364  0.099324
> 1811  0.000250
> 4377  0.020624
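The `y` values shown are small non-integer floats, which scikit-learn treats as a continuous (regression) target, whereas `scoring='accuracy'` expects class labels. The sketch below uses only the five values printed above to show how scikit-learn classifies this target and that the accuracy metric rejects it; `type_of_target` is a diagnostic added here, not part of the original code.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.utils.multiclass import type_of_target

# The five target values printed in the question
y = np.array([0.006755, 0.001001, 0.099324, 0.000250, 0.020624])

# How scikit-learn classifies this target
target_kind = type_of_target(y)
print(target_kind)  # continuous

# scoring='accuracy' is backed by accuracy_score, which rejects continuous targets
try:
    accuracy_score(y, y)
    accuracy_rejected = False
except ValueError as err:
    accuracy_rejected = True
    print("ValueError:", err)
```

If the target really is continuous (the `rmsle` loss points that way), `KerasRegressor` with a regression scorer such as `scoring='neg_mean_squared_error'` may be a better match than `KerasClassifier` with `'accuracy'`.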

Can you suggest what changes I need to make to get rid of this error?
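One change that is probably needed whatever the exact message turns out to be: as far as I know, `KerasClassifier` hands a grid parameter to `build_fn` only when `build_fn` has an argument with that name, other names are matched against the model's `fit`/`predict` arguments, and names that match neither are rejected with a `ValueError`. `build_Class()` takes no arguments, so the `optimizer` values in `parameter` cannot reach `model.compile`. The sketch below imitates that name-based matching with `inspect.signature`; `route_params` and both stub builders are hypothetical stand-ins, not Keras code.

```python
import inspect


def route_params(build_fn, params):
    """Hypothetical helper: split one grid candidate into the entries whose
    names match build_fn's arguments and everything else."""
    accepted = inspect.signature(build_fn).parameters
    to_build = {k: v for k, v in params.items() if k in accepted}
    leftover = {k: v for k, v in params.items() if k not in accepted}
    return to_build, leftover


def build_Class_original():
    # Stub of the question's zero-argument builder; a real one returns a model
    return None


def build_Class_with_optimizer(optimizer='adam'):
    # Stub of a builder that accepts the optimizer name from the grid
    return optimizer


# One candidate from the question's grid
candidate = {'batch_size': 10, 'nb_epoch': 100, 'optimizer': 'rmsprop'}

original_split = route_params(build_Class_original, candidate)
fixed_split = route_params(build_Class_with_optimizer, candidate)
print(original_split)
# ({}, {'batch_size': 10, 'nb_epoch': 100, 'optimizer': 'rmsprop'})
print(fixed_split)
# ({'optimizer': 'rmsprop'}, {'batch_size': 10, 'nb_epoch': 100})
```

Applied to the question's code, this would mean declaring `def build_Class(optimizer='adam'):` so that each grid value is passed on through `model.compile(..., optimizer=optimizer, ...)`; the default `'adam'` is an arbitrary choice.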

0 answers:

No answers