sklearn中的自定义变压器

时间:2019-02-15 02:42:57

标签: python machine-learning scikit-learn data-science transformer

我正在sklearn中构建一个变压器,该变压器会丢弃相关系数低于指定阈值的特征。

它适用于训练集。但是,当我转换测试集时。测试仪上的所有功能均消失。我假设变压器正在计算测试数据和训练标签之间的相关性,并且由于这些相关性都很低,因此它将删除所有功能。如何使其仅在训练集上计算相关性,并在变换中从测试集中删除那些特征?

class CorrelatedFeatures(BaseEstimator, TransformerMixin): #Selects only features that have a correlation coefficient higher than threshold with the response label
    def __init__(self, response, threshold=0.1):
        self.threshold = threshold
        self.response = response
    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        df = pd.concat([X, self.response], axis=1)
        cols = df.columns[abs(df.corr()[df.columns[-1]]) > self.threshold].drop(self.response.columns)
        return X[cols]

1 个答案:

答案 0 :(得分:0)

您可以计算并存储该相关性并将要删除的列存储在fit()中,而在transform()中只需转换这些列即可。

类似这样的东西:

....
....

def fit(self, X, y=None):
    df = pd.concat([X, self.response], axis=1)
    self.cols = df.columns[abs(df.corr()[df.columns[-1]]) > self.threshold].drop(self.response.columns)
    return self
def transform(self, X, y=None):
    return X[self.cols]