LightGBM / sklearn Pipeline transformers are not run on fit_params['eval_set']

Time: 2017-09-06 15:57:05

Tags: scikit-learn pipeline lightgbm

When using early_stopping_rounds, or an external test set, with GridSearchCV in combination with a Pipeline, the eval_set seems to be ignored by the pipeline's fit. The pipeline's fit and transform functions are applied only to the training data; the eval_set data is passed straight to the final estimator without the transformers being run on it.

Is there a good way to solve this? I have attached a small example showing that the eval_set is not transformed by the pipeline. I have read that the classifier can be extended somehow, but I am not sure how to access the pipeline object from within it.

    from sklearn.base import BaseEstimator
    from sklearn.base import TransformerMixin
    from sklearn.utils.validation import check_array, check_is_fitted
    from sklearn.pipeline import Pipeline
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import train_test_split
    import lightgbm as lgbm
    import numpy as np


    class Transformer(BaseEstimator, TransformerMixin):
        def __init__(self):
            pass

        def get_params(self, deep=True):
            return dict()

        def fit(self, X, y=None):
            X = check_array(X, dtype=object)
            print(X.shape)
            self.input_shape_ = X.shape
            return self

        def set_params(self, **parameters):
            self.__dict__.update(parameters)
            return self

        def transform(self, X):
            # Check that fit has been called
            check_is_fitted(self, ['input_shape_'])

            # Input validation
            X = check_array(X, dtype=object)

            # Square the value stored in each Foo object
            Xt = np.zeros((len(X), 1), dtype=np.float32)
            for i in range(Xt.shape[0]):
                Xt[i] = np.float32(X[i][0].s)**2.0
            print(Xt)
            return Xt


    class Foo:
        def __init__(self, s):
            self.s = s


    if __name__ == '__main__':
        x = np.array([Foo(x) for x in range(10)]).reshape(-1, 1)
        y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, stratify=y, test_size=0.2, random_state=42)

        params = {'lgbm__learning_rate': [0.05, 0.1]}

        """
        static_params = {'n_estimators': 100,  # 0,
                         }
        """
        static_params = {'n_estimators': 100,  # 0,
                         'early_stopping_rounds': 5,
                         'eval_metric': 'binary_logloss',
                         'is_unbalance': False,
                         'eval_set': [[x_test, y_test]]  # raw, untransformed Foo objects
                         }

        pipe = Pipeline(steps=[('transformer', Transformer()),
                               ('lgbm', lgbm.LGBMClassifier(**static_params))])
        estimator = GridSearchCV(pipe, scoring='roc_auc', param_grid=params,
                                 cv=2, n_jobs=-1)
        print(x_train)
        print(y_train)
        estimator.fit(x_train, y_train)
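One possible direction, offered only as a minimal sketch and not as an established fix: fit the transformer steps on the training data by hand, run the fitted transformers over every (X, y) pair in eval_set, and only then call fit on the final step with the transformed eval_set. The helper name fit_with_transformed_eval_set and the three stand-in classes below are hypothetical and exist only so the sketch runs without scikit-learn or LightGBM installed. The helper relies on nothing but a steps attribute holding (name, estimator) pairs, which is how sklearn.pipeline.Pipeline stores its steps; using it with a real Pipeline whose last step is lgbm.LGBMClassifier is an untested assumption, and it does not by itself cover the GridSearchCV case from the question.

```python
def fit_with_transformed_eval_set(pipe, X, y, eval_set, **final_fit_params):
    """Fit every transformer step on the training data, apply the fitted
    transformers to each (X_val, y_val) pair in eval_set, then fit the
    final step with the transformed eval_set.

    Assumes pipe.steps is a list of (name, estimator) pairs and that every
    step except the last implements fit() and transform().
    """
    Xt = X
    eval_t = [(X_val, y_val) for X_val, y_val in eval_set]
    for _name, step in pipe.steps[:-1]:
        # Fit on training data only, so nothing leaks from the eval set.
        Xt = step.fit(Xt, y).transform(Xt)
        # Reuse the already fitted transformer on the evaluation data.
        eval_t = [(step.transform(X_val), y_val) for X_val, y_val in eval_t]
    _final_name, final_estimator = pipe.steps[-1]
    final_estimator.fit(Xt, y, eval_set=eval_t, **final_fit_params)
    return pipe


class SquareTransformer:
    """Hypothetical stand-in transformer: squares every feature value."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return [[float(v) ** 2 for v in row] for row in X]


class RecordingModel:
    """Hypothetical stand-in for a final estimator whose fit() takes eval_set."""

    def fit(self, X, y, eval_set=None, **kwargs):
        # Record what was received so the effect of the helper is visible.
        self.seen_X = X
        self.seen_eval = eval_set
        return self


class StandInPipeline:
    """Hypothetical stand-in exposing only the steps attribute of a Pipeline."""

    def __init__(self, steps):
        self.steps = steps


pipe = StandInPipeline([("transformer", SquareTransformer()),
                        ("model", RecordingModel())])
fit_with_transformed_eval_set(pipe, [[0], [1], [2]], [0, 1, 0],
                              eval_set=[([[3], [4]], [1, 0])])
model = pipe.steps[-1][1]
print(model.seen_X)     # training features after the transformer
print(model.seen_eval)  # eval_set features after the same transformer
```

In this sketch both the training features and the eval_set features reach the final estimator in squared form, which is the behavior the question expects from the pipeline. To use the same idea inside GridSearchCV, the logic would presumably have to live in the fit method of a custom estimator that wraps the transformers and the classifier together, so that each cross-validation fold refits the transformers before transforming eval_set.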

0 answers:

No answers