在sklearn中编写自定义变换器,它返回.transform中估计器的.predict

时间:2018-03-07 21:01:40

标签: python scikit-learn sklearn-pandas

我们有自定义变压器

class EstimatorTransformer(base.BaseEstimator, base.TransformerMixin):

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y):
        self = self.estimator.fit(X,y)
        return self

    def transform(self, X):
        return self.estimator.predict(X)

还有一个断言声明

city_trans = EstimatorTransformer(city_est)
city_trans.fit(features,target)
assert ([r[0] for r in city_trans.transform(data[:5])]
        == city_est.predict(data[:5]))

其中

  

city_est是我们可以通过的估算工具。我正在使用city_est = city_est = Ridge(alpha = 1)

但我在self = self.estimator.fit(X,y)收到错误。我在这里可能做错了什么。我知道fit()会返回self。我应该如何使这个断言起作用?

1 个答案:

答案 0 :(得分:0)

你在这一行做错了:

self = self.estimator.fit(X,y)

这里,self是当前的类(EstimatorTransformer),你试图为它分配一个不同的类。

你可以写:

def fit(self, X, y):
    self.estimator.fit(X,y)
    return self

它会起作用。