Object is not iterable when implementing transformers for a Pipeline in scikit-learn

Posted: 2018-03-20 21:47:49

Tags: python machine-learning scikit-learn

I have a list of strings that I want to classify, and I am using a Pipeline object.

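I implemented two dummy transformers: one converts the data into a specific format (the one accepted by another transformer), and the other converts the data back into its original form (a kind of inverse).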

X and y are lists of strings; suppose X=['London is great', 'London is beautiful', 'I hate London'] and y=['p','p','n']. I want to convert X into a list of lists of strings: X=[['London is great'], ['London is beautiful'], ['I hate London']]
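The conversion described above and its inverse boil down to two list comprehensions. A minimal sketch on the sample data from the question; the helper names `wrap` and `unwrap` are hypothetical and used only for illustration:

```python
# Sample data from the question
X = ['London is great', 'London is beautiful', 'I hate London']


def wrap(docs):
    """Hypothetical helper: put each string into its own one-element list."""
    return [[doc] for doc in docs]


def unwrap(wrapped):
    """Hypothetical helper: take the single string back out of each inner list."""
    return [item[0] for item in wrapped]


wrapped = wrap(X)
print(wrapped)               # [['London is great'], ['London is beautiful'], ['I hate London']]
print(unwrap(wrapped) == X)  # True: unwrap undoes wrap
```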

My code is as follows:

from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.pipeline import Pipeline
from sklearn.base import TransformerMixin, BaseEstimator


vectorizer = CountVectorizer(input=u'content',
                             analyzer=u'word',
                             lowercase=True,
                             stop_words=cached_stopwords,
                             strip_accents=u'unicode',
                             ngram_range=(1, 3), binary=False)

estimators = [('pre_ds', PreprocessPreDS()),
              ('post_ds', PreprocesarPostDS()),
              ('vectorizer', vectorizer),
              ('feature_selector', SelectKBest(chi2, k=100)),
              ('clf', MultinomialNB())]  
# create the pipeline
pipe = Pipeline(estimators)
pipe.fit(X_train, y_train)

My custom transformers are as follows:

class PreprocessPreDS(BaseEstimator, TransformerMixin):

    def __init__(self):
        pass

    def transform(self, X, *_):
        return [[x] for x in X]

    def fit(self, *_):
        return self

    def fit_transform(self, X, y=None, **fit_params):
        return self.fit(X)

    def get_params(self, deep=True):
        """
        :param deep: ignored, as suggested by scikit learn's documentation
        :return: dict containing each parameter from the model as name and its current value
        """
        return {}

    def set_params(self, **parameters):
        """
        set all parameters for current objects
        :param parameters: dict containing its keys and values to be initialised
        :return: self
        """
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self


class PreprocesarPostDS(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def transform(self, X, *_):
        return [x[0] for x in X]

    def fit(self, *_):
        return self

    def fit_transform(self, X, y=None, **fit_params):
        return self.fit(X)

    def get_params(self, deep=True):
        """
        :param deep: ignored, as suggested by scikit learn's documentation
        :return: dict containing each parameter from the model as name and its current value
        """
        return {}

    def set_params(self, **parameters):
        """
        set all parameters for current objects
        :param parameters: dict containing its keys and values to be initialised
        :return: self
        """
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

When I run this code, I get the following error:

Traceback (most recent call last):
  File "/home/rodrigo/nb/train_nb_pipeline.py", line 449, in <module>
    process(args.label, args.evaluate, args.label_all, corpus=args.corpus_path)
  File "/home/rodrigo/nb/train_nb_pipeline.py", line 179, in process
    pipe.fit(X_train, y_train)
  File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/pipeline.py", line 248, in fit
    Xt, fit_params = self._fit(X, y, **fit_params)
  File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/pipeline.py", line 213, in _fit
    **fit_params_steps[name])
  File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/externals/joblib/memory.py", line 362, in __call__
    return self.func(*args, **kwargs)
  File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/pipeline.py", line 581, in _fit_transform_one
    res = transformer.fit_transform(X, y, **fit_params)
  File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/feature_extraction/text.py", line 869, in fit_transform
    self.fixed_vocabulary_)
  File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/feature_extraction/text.py", line 790, in _count_vocab
    for doc in raw_documents:
TypeError: 'PreprocessPostDS' object is not iterable

However, if I exclude ('pre_ds', PreprocessPreDS()) and ('post_ds', PreprocesarPostDS()) from estimators, everything runs fine.
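Why the error appears only once the two custom steps are in the pipeline can be seen with a simplified model of how a pipeline fits its intermediate steps: the output of each step's `fit_transform` becomes the input of the next step. The sketch below does not use scikit-learn at all; `BuggyWrap`, `BuggyUnwrap` and `fit_chain` are hypothetical stand-ins that mimic only this chaining, with the same `fit_transform` bug as the question's classes.

```python
class BuggyWrap:
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return [[x] for x in X]

    def fit_transform(self, X, y=None):
        # Same bug as in the question: returns the fitted object, not the data
        return self.fit(X)


class BuggyUnwrap:
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return [x[0] for x in X]

    def fit_transform(self, X, y=None):
        return self.fit(X)  # same bug


def fit_chain(steps, X, y=None):
    """Simplified model (not scikit-learn's code): feed each step's
    fit_transform output into the next step and return the final output."""
    Xt = X
    for step in steps:
        Xt = step.fit_transform(Xt, y)
    return Xt


X = ['London is great', 'London is beautiful', 'I hate London']
out = fit_chain([BuggyWrap(), BuggyUnwrap()], X)
print(type(out).__name__)  # BuggyUnwrap: an estimator object, not a list of strings

try:
    for doc in out:  # CountVectorizer iterates over its input the same way
        pass
except TypeError as exc:
    print(exc)  # 'BuggyUnwrap' object is not iterable
```

Without the two custom steps, the vectorizer receives the original list of strings, which is why the pipeline then works.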

1 answer:

Answer 0 (score: 1)

Change this:

def fit_transform(self, X, y=None, **fit_params):
    return self.fit(X)

to:

def fit_transform(self, X, y=None, **fit_params):
    return self.fit(X).transform(X)

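In your code, fit_transform() returns self, which is the transformer instance itself (here PreprocessPreDS or PreprocessPostDS), not data. fit_transform() should return the transformed data. The pipeline passes each intermediate step's fit_transform() output to the next step, so CountVectorizer ends up receiving the PreprocessPostDS object instead of a list of strings and fails when it tries to iterate over it.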
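A minimal end-to-end sketch of the corrected pipeline, assuming a recent scikit-learn. Because `TransformerMixin` already provides a `fit_transform` that returns `fit(...).transform(X)`, the classes below simply do not override it (and `BaseEstimator` already supplies `get_params`/`set_params`). Two settings are assumptions that differ from the question: `stop_words='english'` stands in for the undefined `cached_stopwords`, and `SelectKBest` uses `k=5` because this three-sentence toy corpus produces far fewer than 100 features.

```python
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline


class PreprocessPreDS(BaseEstimator, TransformerMixin):
    """Wrap each document in a one-element list."""

    def fit(self, X, y=None):
        return self  # nothing to learn

    def transform(self, X):
        return [[x] for x in X]
    # No fit_transform override: TransformerMixin's version returns
    # self.fit(...).transform(X), i.e. the transformed data.


class PreprocesarPostDS(BaseEstimator, TransformerMixin):
    """Unwrap each one-element list back into a plain string."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return [x[0] for x in X]


X_train = ['London is great', 'London is beautiful', 'I hate London']
y_train = ['p', 'p', 'n']

pipe = Pipeline([
    ('pre_ds', PreprocessPreDS()),
    ('post_ds', PreprocesarPostDS()),
    # Assumption: built-in English stop-word list instead of cached_stopwords
    ('vectorizer', CountVectorizer(lowercase=True, stop_words='english',
                                   strip_accents='unicode', ngram_range=(1, 3))),
    # Assumption: k=5 so k does not exceed the toy corpus's feature count
    ('feature_selector', SelectKBest(chi2, k=5)),
    ('clf', MultinomialNB()),
])
pipe.fit(X_train, y_train)
print(pipe.predict(['London is great']))
```

Relying on the inherited `fit_transform` avoids repeating the `fit(X).transform(X)` fix in every custom step.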