Custom transformer in an sklearn Pipeline does not receive n_samples rows in transform()

Time: 2019-05-01 19:28:01

Tags: scikit-learn pipeline gridsearchcv

For this custom transformer:

    from sklearn.base import TransformerMixin

    class Processor(TransformerMixin):
        index = None

        def __init__(self, index):
            self.index = index

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            print(X)
            print('___')
            return X[0][self.index]

and this pipeline:

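    from sklearn.feature_selection import SelectKBest, f_classif
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import GridSearchCV, KFold
    from sklearn.pipeline import Pipeline

    GridSearchCV(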
        Pipeline([
            ('extractor', Processor(index=0)),
            ('selector', SelectKBest(score_func=f_classif)),
            ('clf', None)
        ]),
        cv=KFold(n_splits=2, random_state=0),
        return_train_score=False,
        scoring=['explained_variance', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_median_absolute_error','r2'],
        refit='r2',
        param_grid=[{
            # 'selector__k': ['all'],
            'clf__normalize': [True, False],
            'clf__fit_intercept': [True, False],
            'clf': [LinearRegression()]
        }],
        n_jobs=-1
    )

the printed output is:

    [[17.05214286  0.69666262]]
    ___

even though my training features are:

    [[17.05214286  0.69666262]
     [88.36863636  0.69692261]]
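For reference, this is how the indexing used in `transform` behaves on that matrix, assuming it arrives as a NumPy array (the variable names below are illustrative only):

```python
import numpy as np

# Training features as posted: 2 samples, 2 features.
X = np.array([[17.05214286, 0.69666262],
              [88.36863636, 0.69692261]])
index = 0

first_row_value = X[0][index]  # one number taken from the first row only
column_1d = X[:, index]        # 1-D, one value per sample
column_2d = X[:, [index]]      # 2-D, one row per sample

print(np.ndim(first_row_value), column_1d.shape, column_2d.shape)  # 0 (2,) (2, 1)
```

`X[0][index]` yields a scalar no matter how many rows `X` has, whereas the column slices keep one entry per sample.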

Can someone tell me why my custom transformer does not receive a matrix with n_samples rows?
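One way to check which rows reach `fit`/`transform` during the search: `GridSearchCV` fits the pipeline on each cross-validation training fold rather than on the whole training set. A minimal sketch using the two-row feature matrix from the question prints the indices `KFold(n_splits=2)` produces. `random_state` is left out here because newer scikit-learn releases raise an error when it is set without `shuffle=True`.

```python
import numpy as np
from sklearn.model_selection import KFold

# Training features as posted: 2 samples, 2 features.
X = np.array([[17.05214286, 0.69666262],
              [88.36863636, 0.69692261]])

kf = KFold(n_splits=2)  # no shuffling: folds follow row order
train_shapes = []
for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
    # Inside GridSearchCV the pipeline would be fitted on X[train_idx].
    train_shapes.append(X[train_idx].shape)
    print(fold, "train rows:", train_idx, "test rows:", test_idx)
```

With two samples and two splits, each training fold (and each test fold) holds a single row, so `transform` sees a `(1, 2)` array during cross-validation; only the final refit (`refit='r2'`) uses every row.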

0 answers
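For comparison, a minimal sketch of a transformer that keeps every row, assuming `X` is a 2-D NumPy array. `ColumnSelector` is a hypothetical name, and selecting a column is only one possible reading of what `index` is meant to pick. It also inherits from `BaseEstimator`, which supplies the `get_params`/`set_params` methods that `GridSearchCV` relies on when it clones pipeline steps; `Processor` as posted inherits only from `TransformerMixin`.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class ColumnSelector(TransformerMixin, BaseEstimator):
    """Hypothetical transformer that keeps one column and every row."""

    def __init__(self, index=0):
        # Store constructor arguments unchanged so get_params()/clone() work.
        self.index = index

    def fit(self, X, y=None):
        # Nothing to learn; returning self allows fit(...).transform(...).
        return self

    def transform(self, X):
        X = np.asarray(X)
        # Slice a column, not a row: output shape is (n_samples, 1).
        return X[:, [self.index]]


X = np.array([[17.05214286, 0.69666262],
              [88.36863636, 0.69692261]])

# GridSearchCV clones each step; BaseEstimator provides get_params() for that.
selector = clone(ColumnSelector(index=0))
out = selector.fit_transform(X)
print(out.shape)  # (2, 1)
```

Because the output has as many rows as the input, a downstream step such as `SelectKBest` still receives one row per sample.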