Using both text and numeric features in a scikit-learn pipeline

Date: 2019-03-13 20:05:04

Tags: python machine-learning scikit-learn

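I am working on a text classification model. I have already processed the text and have a working pipeline. Now I want to add sentiment scores to the predictive model.

Here is the working pipeline. X is the block of text and y is the text category.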

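# logistic regression pipeline
# imports assumed (not shown in the original post); en_stops is a stop-word list defined elsewhere
from sklearn import linear_model as lm
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline

log_res_clf = Pipeline([('vect', CountVectorizer(stop_words=en_stops)),
    ('tfidf', TfidfTransformer(use_idf=True)),
    ('clf', lm.LogisticRegression(solver='saga', multi_class='multinomial', random_state=42))])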

I tried simply adding the other predictors to the X DataFrame, but got the following error:

ValueError: Found input variables with inconsistent numbers of samples: [3, 23322]
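A plausible explanation for the 3 in this message, assuming X is a pandas DataFrame with the three columns listed under the inputs: iterating over a DataFrame yields its column labels rather than its rows, so CountVectorizer treats the three column names as three documents. The sketch below reproduces that behaviour on a made-up toy frame; the data values are invented for illustration.

```python
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical toy frame using the column names from the question.
df = pd.DataFrame({
    "cleaned": ["good movie", "bad movie", "great plot", "dull plot"],
    "polarity": [0.7, -0.6, 0.8, -0.3],
    "subjectivity": [0.6, 0.7, 0.5, 0.4],
})

# Iterating a DataFrame yields column labels, not rows.
labels = list(df)

# Whole frame: the vectorizer sees the 3 column names as 3 documents.
whole_frame_counts = CountVectorizer().fit_transform(df)

# Text column only: one row per sample, as the classifier expects.
text_counts = CountVectorizer().fit_transform(df["cleaned"])

print(labels)
print(whole_frame_counts.shape[0], text_counts.shape[0])
```

If this is what is happening, the vectorizer output has 3 rows while y has one entry per sample, which would account for the mismatch between 3 and 23322.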

So I did some more research and found an example pipeline that concatenates features and seems to have worked for other people:

# build Log Res Pipeline
# imports assumed (not shown in the original post)
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import FunctionTransformer

get_numeric_data = FunctionTransformer(lambda x: x[['polarity', 'subjectivity']], validate=False)

log_res_clf = Pipeline([
    ('features', FeatureUnion([
        ('numeric_features', Pipeline([
            ('selector', get_numeric_data),
        ])),
        ('word_features', Pipeline([
            ('vect', CountVectorizer(stop_words=en_stops)),
            ('tfidf', TfidfTransformer(use_idf=True)),
        ])),
    ])),
    ('clf', lm.LogisticRegression(solver='saga', multi_class='multinomial', random_state=42)),
])

This gives me a shape-related error on my data:

ValueError: blocks[0,:] has incompatible row dimensions. Got blocks[0,1].shape[0] == 3, expected 23322.
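The 3 here equals the number of columns in X, which suggests that the word_features branch receives the whole DataFrame, so CountVectorizer iterates over column names instead of documents. A minimal sketch of one possible fix follows, assuming X is a DataFrame with the columns cleaned, polarity and subjectivity: give the text branch its own selector that returns the text column as a 1-D Series. The name get_text_data and the toy data are illustrative, and stop words and solver settings are left at their defaults to keep the sketch short.

```python
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import FunctionTransformer

# Hypothetical toy data with the column names from the question.
X = pd.DataFrame({
    "cleaned": ["good movie", "bad movie", "great plot", "dull plot"],
    "polarity": [0.7, -0.6, 0.8, -0.3],
    "subjectivity": [0.6, 0.7, 0.5, 0.4],
})
y = [0, 1, 0, 1]

# Selector for the numeric columns (2-D), as in the question.
get_numeric_data = FunctionTransformer(
    lambda x: x[["polarity", "subjectivity"]], validate=False)
# Illustrative selector for the text column (1-D Series of documents).
get_text_data = FunctionTransformer(lambda x: x["cleaned"], validate=False)

clf = Pipeline([
    ("features", FeatureUnion([
        ("numeric_features", Pipeline([("selector", get_numeric_data)])),
        ("word_features", Pipeline([
            ("selector", get_text_data),  # pick the text column first
            ("vect", CountVectorizer()),
            ("tfidf", TfidfTransformer(use_idf=True)),
        ])),
    ])),
    ("clf", LogisticRegression()),
])

clf.fit(X, y)
features = clf.named_steps["features"].transform(X)
preds = clf.predict(X)
print(features.shape)
```

With both branches producing one row per sample, the horizontal stacking step in FeatureUnion has compatible row dimensions.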

I took this pipeline from the accepted answer to: how to featureUnion numerical and text features in python sklearn properly

Here is the full traceback:

<ipython-input-119-7391cab6b5c3> in <module>
----> 1 log_res_clf.fit(X_train, y_train)
      2 log_res_preds = log_res_clf.predict(X_test)
      3 log_res_probs = log_res_clf.predict_proba(X_test)
      4 
      5 log_res_topn = np.argsort(log_res_probs, axis = 1)[:,-n:]

C:\Anaconda3\lib\site-packages\sklearn\pipeline.py in fit(self, X, y, **fit_params)
    263             This estimator
    264         """
--> 265         Xt, fit_params = self._fit(X, y, **fit_params)
    266         if self._final_estimator is not None:
    267             self._final_estimator.fit(Xt, y, **fit_params)

C:\Anaconda3\lib\site-packages\sklearn\pipeline.py in _fit(self, X, y, **fit_params)
    228                 Xt, fitted_transformer = fit_transform_one_cached(
    229                     cloned_transformer, Xt, y, None,
--> 230                     **fit_params_steps[name])
    231                 # Replace the transformer of the step with the fitted
    232                 # transformer. This is necessary when loading the transformer

C:\Anaconda3\lib\site-packages\sklearn\externals\joblib\memory.py in __call__(self, *args, **kwargs)
    340 
    341     def __call__(self, *args, **kwargs):
--> 342         return self.func(*args, **kwargs)
    343 
    344     def call_and_shelve(self, *args, **kwargs):

C:\Anaconda3\lib\site-packages\sklearn\pipeline.py in _fit_transform_one(transformer, X, y, weight, **fit_params)
    612 def _fit_transform_one(transformer, X, y, weight, **fit_params):
    613     if hasattr(transformer, 'fit_transform'):
--> 614         res = transformer.fit_transform(X, y, **fit_params)
    615     else:
    616         res = transformer.fit(X, y, **fit_params).transform(X)

C:\Anaconda3\lib\site-packages\sklearn\pipeline.py in fit_transform(self, X, y, **fit_params)
    799         self._update_transformer_list(transformers)
    800         if any(sparse.issparse(f) for f in Xs):
--> 801             Xs = sparse.hstack(Xs).tocsr()
    802         else:
    803             Xs = np.hstack(Xs)

C:\Anaconda3\lib\site-packages\scipy\sparse\construct.py in hstack(blocks, format, dtype)
    463 
    464     """
--> 465     return bmat([blocks], format=format, dtype=dtype)
    466 
    467 

C:\Anaconda3\lib\site-packages\scipy\sparse\construct.py in bmat(blocks, format, dtype)
    584                                                     exp=brow_lengths[i],
    585                                                     got=A.shape[0]))
--> 586                     raise ValueError(msg)
    587 
    588                 if bcol_lengths[j] == 0:

My inputs are:

X = df[['cleaned', 'polarity', 'subjectivity']]  # text block, numeric, numeric
y = df['category_id']  # encoded class
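Given this input layout, another option in scikit-learn 0.20 and later is ColumnTransformer, which routes each column to its own transformer. The sketch below is illustrative rather than a drop-in answer: the toy data is invented, TfidfVectorizer is swapped in for the CountVectorizer plus TfidfTransformer pair (the scikit-learn documentation describes the two as equivalent), and the classifier is left at its defaults.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Hypothetical toy frame with the column names from the question.
df = pd.DataFrame({
    "cleaned": ["good movie", "bad movie", "great plot", "dull plot"],
    "polarity": [0.7, -0.6, 0.8, -0.3],
    "subjectivity": [0.6, 0.7, 0.5, 0.4],
    "category_id": [0, 1, 0, 1],
})
X = df[["cleaned", "polarity", "subjectivity"]]
y = df["category_id"]

preprocess = ColumnTransformer([
    # A string (not a list) selects a 1-D column, which text vectorizers expect.
    ("word_features", TfidfVectorizer(), "cleaned"),
    # A list selects a 2-D block; "passthrough" forwards it unchanged.
    ("numeric_features", "passthrough", ["polarity", "subjectivity"]),
])

clf = Pipeline([
    ("features", preprocess),
    ("clf", LogisticRegression()),
])

clf.fit(X, y)
features = clf.named_steps["features"].transform(X)
preds = clf.predict(X)
print(features.shape)
```

Passing the text column as a plain string rather than a one-element list matters here: a list would hand the vectorizer a 2-D frame and bring back the column-name iteration problem.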
