sklearn: Have an estimator that filters samples

Time: 2014-07-22 19:31:35

Tags: python scikit-learn

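I'm trying to implement my own Imputer. Under certain conditions, I would like to filter out some of the training samples (the ones I consider low quality).

However, the transform method returns only X and not y, and y itself is just a numpy array (which, as far as I know, I cannot filter in place). On top of that, when I use GridSearchCV, the y my transform method receives is None. I can't seem to find a way around this.

Just to clarify: I know perfectly well how to filter arrays. What I can't find is a way to fit sample filtering on the y vector into the current API.

I really want to do this from a BaseEstimator implementation so that I can use it with GridSearchCV (it has a few parameters). Am I missing a different way to achieve sample filtering (not through BaseEstimator, but still GridSearchCV compatible)? Is there some way around the current API?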
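
To make the problem concrete, here is a minimal sketch that needs only numpy. The class name `DropLowSumRows` and the row-sum criterion are made up for illustration and are not part of scikit-learn. Because `transform` can only hand back X, the filtered X and the untouched y end up with different lengths:

```python
import numpy as np


class DropLowSumRows:
    """Hypothetical transformer-like class, used only to illustrate the issue."""

    def __init__(self, perc=50):
        self.perc = perc

    def fit(self, X, y=None):
        # Learn a threshold on the per-sample sums.
        self.threshold_ = np.percentile(X.sum(axis=1), self.perc)
        return self

    def transform(self, X):
        # Rows are dropped here, but there is no channel to drop the
        # matching entries of y: only X is returned.
        return X[X.sum(axis=1) >= self.threshold_]


X = np.array([[1.0, 1.0], [5.0, 5.0], [2.0, 2.0], [9.0, 9.0]])
y = np.array([0, 1, 0, 1])

Xt = DropLowSumRows(perc=50).fit(X, y).transform(X)
print(Xt.shape, y.shape)  # (2, 2) (4,)
```

Any estimator fitted on `Xt` and `y` after this step would fail or, worse, pair samples with the wrong labels.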

2 answers:

Answer 0 (score: 8)

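I found a solution, which has three parts:

  1. Have the if idx == id(self.X): line. This makes sure samples are filtered only on the training set.
  2. Override fit_transform to make sure the transform method receives y and not None.
  3. Override Pipeline to allow transform to return that y.

Here is sample code demonstrating it. It may not cover every detail, but I think it solves the main problems with the API.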

    from sklearn.base import BaseEstimator, TransformerMixin
    import numpy as np
    from sklearn.pipeline import Pipeline
    from sklearn.naive_bayes import GaussianNB
    from sklearn import cross_validation
    from sklearn.grid_search import GridSearchCV
    from sklearn.externals import six
    
    class SampleAndFeatureFilter(BaseEstimator, TransformerMixin):
        def __init__(self, perc = None):
            self.perc = perc
    
        def fit(self, X, y=None):
            self.X = X
            sum_per_feature = X.sum(0)
            sum_per_sample = X.sum(1)
            self.featurefilter = sum_per_feature >= np.percentile(sum_per_feature, self.perc)
            self.samplefilter  = sum_per_sample >= np.percentile(sum_per_sample, self.perc)
            return self
    
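        def transform(self, X, y=None, copy=None):
            # Remember the identity of the incoming array before slicing it.
            idx = id(X)
            X = X[:, self.featurefilter]
            # Only the exact array object seen in fit() (the training set)
            # has its samples filtered; any other input keeps all its rows.
            if idx == id(self.X):
                X = X[self.samplefilter, :]
                if y is not None:
                    y = y[self.samplefilter]
                return X, y
            return X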
    
        def fit_transform(self, X, y=None, **fit_params):
            if y is None:
                return self.fit(X, **fit_params).transform(X)
            else:
                return self.fit(X, y, **fit_params).transform(X,y)
    
    class PipelineWithSampleFiltering(Pipeline):
        def fit_transform(self, X, y=None, **fit_params):
            Xt, yt, fit_params = self._pre_transform(X, y, **fit_params)
            if hasattr(self.steps[-1][-1], 'fit_transform'):
                return self.steps[-1][-1].fit_transform(Xt, yt, **fit_params)
            else:
                return self.steps[-1][-1].fit(Xt, yt, **fit_params).transform(Xt)
    
        def fit(self, X, y=None, **fit_params):
            Xt, yt, fit_params = self._pre_transform(X, y, **fit_params)
            self.steps[-1][-1].fit(Xt, yt, **fit_params)
            return self
    
        def _pre_transform(self, X, y=None, **fit_params):
            fit_params_steps = dict((step, {}) for step, _ in self.steps)
            for pname, pval in six.iteritems(fit_params):
                step, param = pname.split('__', 1)
                fit_params_steps[step][param] = pval
            Xt = X
            yt = y
            for name, transform in self.steps[:-1]:
                if hasattr(transform, "fit_transform"):
                    res = transform.fit_transform(Xt, yt, **fit_params_steps[name])
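                    # An (X, y) tuple means the step also filtered samples;
                    # len(res) == 2 would misfire on a two-row array.
                    if isinstance(res, tuple):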
                        Xt, yt = res
                    else:
                        Xt = res
                else:
                    Xt = transform.fit(Xt, yt, **fit_params_steps[name]) \
                                  .transform(Xt)
            return Xt, yt, fit_params_steps[self.steps[-1][0]]
    
    if __name__ == '__main__':
        X = np.random.random((100,30))
        y = np.random.random_integers(0, 1, 100)
        pipe = PipelineWithSampleFiltering([('flt', SampleAndFeatureFilter()), ('cls', GaussianNB())])
        X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size = 0.3, random_state = 42)
        kfold = cross_validation.KFold(len(y_train), 10)
        clf = GridSearchCV(pipe, cv = kfold, param_grid = {'flt__perc':[10,20,30,40,50,60,70,80]}, n_jobs = 1)
        clf.fit(X_train, y_train)
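
One thing worth keeping in mind about the id(X) check: it compares object identity, not array contents. The following small sketch (variable names are illustrative) shows that only the very same array object passes the check, while an equal copy or a slice of it counts as a different object and would be treated like test data by transform:

```python
import numpy as np

X_train = np.arange(12, dtype=float).reshape(4, 3)
remembered = X_train  # plays the role of self.X stored in fit()

# The same object passed again: identity matches.
same_object = id(X_train) == id(remembered)

# An equal copy has identical values but is a different object.
X_copy = X_train.copy()
copy_matches = id(X_copy) == id(remembered)
values_equal = np.array_equal(X_copy, remembered)

# Slicing also creates a new array object, even though it shares memory.
X_view = X_train[:]
view_matches = id(X_view) == id(remembered)

print(same_object, copy_matches, values_equal, view_matches)
```

So the approach depends on the pipeline handing the untouched training array to both fit and transform, which fit_transform above does.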
    

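Answer 1 (score: 3)

The scikit-learn transformer API is meant for changing the features of the data (in nature, and possibly in number/dimension), but not for changing the number of samples. As of the existing versions of scikit-learn, any transformer that drops or adds samples does not comply with the API (this may be a future addition if deemed important).

With this in mind, it looks like you will have to work around the standard scikit-learn API.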
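
One possible workaround, sketched here under assumptions that are not in the answers above, is to put the filtering and the classifier inside a single meta-estimator, so that X and y are filtered together in `fit` and `predict` never drops rows. The class name `SampleFilteringClassifier`, its `estimator` and `perc` parameters, and the row-sum criterion (borrowed from the first answer) are hypothetical, and the imports assume a scikit-learn version that provides `sklearn.model_selection`.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import GaussianNB


class SampleFilteringClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical wrapper: drop low-sum training rows, then fit `estimator`."""

    def __init__(self, estimator=None, perc=10):
        self.estimator = estimator
        self.perc = perc

    def fit(self, X, y):
        X, y = np.asarray(X), np.asarray(y)
        row_sums = X.sum(axis=1)
        # Filter X and y with the same mask so they stay aligned.
        keep = row_sums >= np.percentile(row_sums, self.perc)
        base = GaussianNB() if self.estimator is None else self.estimator
        self.estimator_ = clone(base).fit(X[keep], y[keep])
        self.classes_ = self.estimator_.classes_
        self.n_samples_kept_ = int(keep.sum())
        return self

    def predict(self, X):
        # No filtering at prediction time: every row gets a prediction.
        return self.estimator_.predict(np.asarray(X))


rng = np.random.RandomState(0)
X = rng.random_sample((100, 30))
y = rng.randint(0, 2, 100)

# `perc` is an ordinary hyperparameter, so it can be tuned directly.
search = GridSearchCV(SampleFilteringClassifier(), {"perc": [10, 30, 50]}, cv=5)
search.fit(X, y)
print(search.best_params_)
```

Since the sample filter lives inside `fit`, each cross-validation fold filters only its own training part, and the held-out part is scored in full.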