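I have successfully used mlxtend's StackingClassifier with most sklearn classifiers, but I can't seem to get it to work with the Keras sklearn wrapper, KerasClassifier. I think the problem is related to how I convert the data for the neural network (using the DataFrame's values attribute), but I can't figure out how to change the data so that the KerasClassifier can be used inside the StackingClassifier.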
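One way to check that hypothesis is a minimal sketch with a small, hypothetical stand-in for training_data. The feature and label columns and all values are made up; only the 'id', 'era' and 'data_type' names are borrowed from the code in this question. It shows that .values on a DataFrame that mixes string and numeric columns returns a single array with dtype=object, and that positional slicing keeps that dtype.

```python
import pandas as pd

# Hypothetical stand-in for training_data: three string columns followed by
# numeric feature columns and an integer label column (values are made up).
training_data = pd.DataFrame({
    "id": ["a1", "a2", "a3", "a4"],
    "era": ["era1", "era1", "era2", "era2"],
    "data_type": ["train"] * 4,
    "feature1": [0.49, 0.58, 0.48, 0.56],
    "feature2": [0.51, 0.47, 0.56, 0.47],
    "target": [0, 1, 1, 0],
})

# .values must pick one dtype for every column; with strings present it is object.
nn_data = training_data.values
x = nn_data[:, 3:5]  # feature columns (positions specific to this stand-in)
y = nn_data[:, 5]    # label column

# Slicing does not change the dtype, so x and y are object arrays as well.
print(nn_data.dtype, x.dtype, y.dtype)  # object object object
```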
Here is my code:
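# Assumed imports (not shown in the original snippet; Keras 2.x-era wrapper path)
from keras.wrappers.scikit_learn import KerasClassifier
from mlxtend.classifier import StackingClassifier
import sklearn.linear_model as LR
nn_data = training_data.values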
nn = prediction_data.drop(['id', 'era', 'data_type'], axis=1)
nn_prediction = nn.values
x = nn_data[:,3:53]
y = nn_data[:,53]
clf1 = KerasClassifier(build_fn=nn_model, epochs=9, batch_size=2000, verbose=2)
lr = LR.LogisticRegression()
sclf = StackingClassifier(classifiers=[clf1], meta_classifier=lr, use_probas=True)
sclf.fit(x, y)
Here is my error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-15-cb767df93cc9> in <module>()
4 mlp = MLPClassifier()
5 sclf = StackingClassifier(classifiers=[clf1], meta_classifier=lr, use_probas=True)
----> 6 sclf.fit(x, y)
7 y_prediction_sclf = sclf.predict_proba(x_pred)
8 print ('final_model logloss = ' + str(metrics.log_loss(y_pred, y_prediction_sclf)))
/Users/wahabkazi/anaconda/lib/python3.6/site-packages/mlxtend/classifier/stacking_classification.py in fit(self, X, y)
118
119 if not self.use_features_in_secondary:
--> 120 self.meta_clf_.fit(meta_features, y)
121 else:
122 self.meta_clf_.fit(np.hstack((X, meta_features)), y)
/Users/wahabkazi/anaconda/lib/python3.6/site-packages/sklearn/linear_model/logistic.py in fit(self, X, y, sample_weight)
1215 X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype,
1216 order="C")
-> 1217 check_classification_targets(y)
1218 self.classes_ = np.unique(y)
1219 n_samples, n_features = X.shape
/Users/wahabkazi/anaconda/lib/python3.6/site-packages/sklearn/utils/multiclass.py in check_classification_targets(y)
170 if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
171 'multilabel-indicator', 'multilabel-sequences']:
--> 172 raise ValueError("Unknown label type: %r" % y_type)
173
174
ValueError: Unknown label type: 'unknown'
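The traceback ends in check_classification_targets, which relies on sklearn.utils.multiclass.type_of_target. The following is a minimal sketch, assuming the labels reach the meta-classifier as an object array like the one produced by slicing .values here; the meta-feature numbers are made up and only mimic predict_proba output (use_probas=True). type_of_target returns 'unknown' for an object array whose first element is not a string, so LogisticRegression rejects it, while the same labels cast to int are recognized as 'binary'.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils.multiclass import type_of_target

# Made-up meta-features shaped like predict_proba output (use_probas=True).
meta_features = np.array([[0.2, 0.8], [0.7, 0.3], [0.4, 0.6], [0.9, 0.1]])
# Labels stored in an object array, as slicing an object .values array yields.
y_object = np.array([1.0, 0.0, 1.0, 0.0], dtype=object)

# Object array whose first element is not a str -> 'unknown'.
print(type_of_target(y_object))              # unknown
print(type_of_target(y_object.astype(int)))  # binary

# This mirrors the failing meta-classifier step in the traceback.
try:
    LogisticRegression().fit(meta_features, y_object)
except ValueError as err:
    print("fit failed:", err)

# With integer labels the same fit succeeds.
clf = LogisticRegression().fit(meta_features, y_object.astype(int))
print(clf.classes_)  # [0 1]
```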
Here is an example row of x: array([0.49282, 0.58077, 0.48948, 0.56762, 0.56107, 0.51168, 0.47458999999999996, 0.56968, 0.47402, 0.40326, 0.54119, 0.5319699999999999, 0.31899, 0.43153, 0.35538000000000003, 0.6613100000000001, 0.42477, 0.6548484, 0.499437, 0.6126699999999999, 0.60285, 0.38813000000000003, 0.49818999999999997, 0.59332, 0.63041, 0.40815, 0.47767, 0.4869, 0.51394, 0.5371600000000001, 0.49223999999999996, 0.44978, 0.49446999999999997, 0.46531999999999996, 0.51057, 0.52177, 0.524243, 0.61623, 0.56988, 0.66293, 0.50138, 0.40333, 0.52337, 0.60795, 0.35748, 0.49677, 0.28295, 0.65342, 0.57915, 0.51136], dtype=object)
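Since the example row is printed with dtype=object, one option is to slice in pandas first and convert with an explicit dtype. This is a sketch under assumptions: training_data is a hypothetical stand-in, the positions 3:5 and 5 stand in for the question's 3:53 and 53, and the labels are assumed to be integer class ids.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for training_data: string columns first,
# then features, then the label (values are made up).
training_data = pd.DataFrame({
    "id": ["a1", "a2", "a3"],
    "era": ["era1", "era1", "era2"],
    "data_type": ["train"] * 3,
    "feature1": [0.49, 0.58, 0.48],
    "feature2": [0.51, 0.47, 0.56],
    "target": [1, 0, 1],
})

# Select by position in pandas, then convert with an explicit dtype so the
# string columns never force the whole array to dtype=object.
# With the real data this would be iloc[:, 3:53] and iloc[:, 53].
x = training_data.iloc[:, 3:5].to_numpy(dtype=np.float64)
# Assumes the labels are integer class ids (e.g. 0/1).
y = training_data.iloc[:, 5].to_numpy(dtype=np.int64)

print(x.dtype, y.dtype)  # float64 int64
```

Arrays built this way could presumably be passed to sclf.fit(x, y) unchanged; casting the existing arrays with x.astype(float) and y.astype(int) would be the equivalent quick fix.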
Thanks in advance for your help!