scikit-learn:如何用管道构建LabelEncoder和OneHotEncoder?

时间:2018-02-22 13:51:43

标签: python scikit-learn one-hot-encoding

在预处理机器学习分类任务的标签时,我需要对带有字符串值的标签进行热编码。来自OneHotEncoder的{​​{1}}或来自sklearn.preprocessing的{​​{1}}需要to_categorical输入。这意味着我需要在一个热编码器之前加上kera.np_utils。我已经手动完成了自定义类:

int

我确信必须使用LabelEncoder在sklearn API中执行此操作,但在使用时:

class LabelOneHotEncoder():
    def __init__(self):
        self.ohe = OneHotEncoder()
        self.le = LabelEncoder()
    def fit_transform(self, x):
        features = self.le.fit_transform( x)
        return self.ohe.fit_transform( features.reshape(-1,1))
    def transform( self, x):
        return self.ohe.transform( self.la.transform( x.reshape(-1,1)))
    def inverse_tranform( self, x):
        return self.le.inverse_transform( self.ohe.inverse_tranform( x))
    def inverse_labels( self, x):
        return self.le.inverse_transform( x)

我从sklearn.pipeline收到错误LabelOneHotEncoder = Pipeline( [ ("le",LabelEncoder), ("ohe", OneHotEncoder)]) 。我的猜测是ValueError: bad input shape ()的输出需要通过添加一个普通的第二轴来重新整形。我不知道如何添加此功能。

3 个答案:

答案 0 :(得分:12)

很奇怪,他们不能很好地一起玩......我很惊讶。我会扩展该类以返回您建议的重新整形数据。

.section .text
.global _start

_start:
    mov x0, #0  // exit with status 0
    mov x8, #93 // svc argument goes in x8, and the argument for 'exit' is 93
    svc #0      // executes a syscall in arm64

然后使用管道应该可以工作。

class ModifiedLabelEncoder(LabelEncoder):

    def fit_transform(self, y, *args, **kwargs):
        return super().fit_transform(y).reshape(-1, 1)

    def transform(self, y, *args, **kwargs):
        return super().transform(y).reshape(-1, 1)

https://github.com/scikit-learn/scikit-learn/blob/a24c8b46/sklearn/preprocessing/label.py#L39

答案 1 :(得分:8)

从scikit-learn 0.20开始,OneHotEncoder接受字符串,因此您不再需要LabelEncoder。您可以在管道中使用它。

答案 2 :(得分:0)

我使用了一个定制的类来包装我的标签编码器函数,它返回整个更新的数据集。

 class CustomLabelEncode(BaseEstimator, TransformerMixin):
  def fit(self, X, y=None):
   return self
  def transform(self, X ,y=None):
    le=LabelEncoder()
    for i in X[cat_cols]:
    X[i]=le.fit_transform(X[i])
    return X 
cat_cols=['Family','Education','Securities Account','CDAccount','Online','CreditCard']
le_ct=make_column_transformer((CustomLabelEncode(),cat_cols),remainder='passthrough') 
pd.DataFrame(ct3.fit_transform(X)) #This will show you your changes
Final_pipeline=make_pipeline(le_ct)

[我已经实现了你可以看到我的github链接] [1]:https://github.com/Ayushmina-20/sklearn_pipeline