Scikit学习:使用管道的GridSearchCV。变压器只学习训练数据吗?

时间:2019-07-15 12:48:59

标签: machine-learning scikit-learn

我正在尝试将GridSearchCV与管道结合起来。代码如下:

pipe = Pipeline(steps=[('preprocessing', preprocess), 
                       ('feature_selection', feature_selector), 
                       ('sampling', sampler), 
                       ('model', RandomForestClassifier())])
search = GridSearchCV(pipe, hparam_grid[model_name], 
                      cv=cross_validator, 
                      n_jobs=-1, scoring=scoring, refit='balanced_accuracy')

这里preprocess是来自TransformerMixin的自定义转换器,它具有fittransform方法。您可以想到它具有自定义的StandardScaler方法。

我的问题如下:当我调用search.fit(X, y)时,管道是否将变压器适合训练数据并将转换应用于测试数据?还是仅使用整个数据集(X)?我尝试从自定义转换器的X函数中打印数组fit的形状,它是原始数据集的形状。但是对我来说,这样做很奇怪。

0 个答案:

没有答案