scikit-learn TransformerMixin has odd fit_transform behavior

Date: 2019-05-03 10:15:30

Tags: python scikit-learn

The fit_transform method inherited from TransformerMixin does not pass the y variable on to the transform method. Here is a minimal example:

from sklearn.base import BaseEstimator, TransformerMixin

class UslessPrint(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass
    def fit(self, X, y=None):
        print(y)  # y as received by fit
        return self
    def transform(self, X, y=None):
        print(y)  # y as received by transform
        return X

usless = UslessPrint()
usless.fit_transform([[1, 2], [2, 1]], [0, 1])

The output is:

[0, 1]
None

instead of the expected

[0, 1]
[0, 1]

Is this normal behavior, or a bug?

1 answer:

Answer 0 (score: 1)

This is the expected behavior of TransformerMixin, not a bug.

The fit_transform of sklearn.base.TransformerMixin does not use y when it calls transform.

From the latest version of sklearn on GitHub:

class TransformerMixin(object):
    """Mixin class for all transformers in scikit-learn."""

    def fit_transform(self, X, y=None, **fit_params):
        """Fit to data, then transform it.

        Fits transformer to X and y with optional parameters fit_params
        and returns a transformed version of X.

        Parameters
        ----------
        X : numpy array of shape [n_samples, n_features]
            Training set.

        y : numpy array of shape [n_samples]
            Target values.

        Returns
        -------
        X_new : numpy array of shape [n_samples, n_features_new]
            Transformed array.

        """
        # non-optimized default implementation; override when a better
        # method is possible for a given clustering algorithm
        if y is None:
            # fit method of arity 1 (unsupervised transformation)
            return self.fit(X, **fit_params).transform(X)
        else:
            # fit method of arity 2 (supervised transformation)
            return self.fit(X, y, **fit_params).transform(X)

As you can see, TransformerMixin passes only X to transform, so y is left as None in your code.
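If you do want y to reach transform, one possible workaround is to override fit_transform in your own class. The following is only a minimal sketch, not something the library prescribes: the class name YAwarePrint and the attributes seen_in_fit_ and seen_in_transform_ are hypothetical names made up for illustration, and the except branch only supplies empty stand-in base classes so the snippet still runs where scikit-learn is not installed.

```python
try:
    from sklearn.base import BaseEstimator, TransformerMixin
except ImportError:
    # Assumption: empty stand-ins so the sketch runs without scikit-learn.
    class BaseEstimator(object):
        pass

    class TransformerMixin(object):
        pass


class YAwarePrint(BaseEstimator, TransformerMixin):
    """Hypothetical variant of UslessPrint whose fit_transform forwards y."""

    def fit(self, X, y=None):
        self.seen_in_fit_ = y  # record what fit received
        return self

    def transform(self, X, y=None):
        self.seen_in_transform_ = y  # record what transform received
        return X

    def fit_transform(self, X, y=None, **fit_params):
        # Replace the mixin default so that y is passed to transform too.
        return self.fit(X, y, **fit_params).transform(X, y)


printer = YAwarePrint()
result = printer.fit_transform([[1, 2], [2, 1]], [0, 1])
print(printer.seen_in_fit_)
print(printer.seen_in_transform_)
```

Keep y optional in transform: as far as I know, Pipeline calls transform without y at prediction time, so a transformer that requires y there would fail once it is used outside fit_transform.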