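This is my first time posting here. For the past few days I've been teaching myself scikit-learn, but I recently ran into an error that has been bugging me ever since.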
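My goal is simply to train an NB classifier clf so that I can feed it an arbitrary list of strings called new_doc and it will predict which class each string most likely belongs to.
Here is my program: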
#Importing stuff
import numpy as np
import pandas as pd
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer, HashingVectorizer, CountVectorizer
from sklearn import metrics
#Opening the csv file
df = pd.read_csv('data.csv', sep=',')
#Randomising the rows in the file
df = df.reindex(np.random.permutation(df.index))
#Extracting features from text, define target y and data X
vect = CountVectorizer()
X = vect.fit_transform(df['Features'])
y = df['Target']
#Partitioning the data into test and training set
SPLIT_PERC = 0.75
split_size = int(len(y)*SPLIT_PERC)
X_train = X[:split_size]
X_test = X[split_size:]
y_train = y[:split_size]
y_test = y[split_size:]
#Training the model
clf = MultinomialNB()
clf.fit(X_train, y_train)
#Evaluating the results
print("Accuracy on training set:")
print(clf.score(X_train, y_train))
print("Accuracy on testing set:")
print(clf.score(X_test, y_test))
y_pred = clf.predict(X_test)
print("Classification Report:")
print(metrics.classification_report(y_test, y_pred))
#Predicting new data
new_doc = ["MacDonalds", "Walmart", "Target", "Starbucks"]
trans_doc = vect.transform(new_doc) #extracting features
y_pred = clf.predict(trans_doc) #predicting
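sklearn.pipeline.Pipeline is imported in the program above but never used. A minimal sketch of how it could bundle the vectorizer and the classifier, so that any new strings always pass through the same fitted vocabulary. The tiny corpus, the labels, and the class names are made up and only stand in for data.csv.

```python
# Minimal sketch: docs/labels are made-up stand-ins for data.csv.
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

docs = [
    "burger fries drive thru",
    "big mac fries",
    "groceries electronics discount",
    "discount store groceries",
    "latte espresso coffee",
    "coffee frappuccino latte",
]
labels = ["fast_food", "fast_food", "retail", "retail", "coffee", "coffee"]

# The pipeline owns a single vectorizer: fit() learns the vocabulary once,
# and predict() reuses it through transform(), so the columns line up.
model = Pipeline([
    ("vect", CountVectorizer()),
    ("clf", MultinomialNB()),
])
model.fit(docs, labels)

new_doc = ["fries and a burger", "espresso please", "discount groceries"]
predictions = model.predict(new_doc)
print(predictions)  # ['fast_food' 'coffee' 'retail']
```

Because the vectorizer lives inside the pipeline, there is no separate vect object that could accidentally be refitted on new_doc.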
But when I run the program, I get the following error on the last line:
>>> y_pred = clf.predict(trans_doc)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Library/Python/2.7/site-packages/sklearn/naive_bayes.py", line 62, in predict
jll = self._joint_log_likelihood(X)
File "/Library/Python/2.7/site-packages/sklearn/naive_bayes.py", line 441, in _joint_log_likelihood
return (safe_sparse_dot(X, self.feature_log_prob_.T)
File "/Library/Python/2.7/site-packages/sklearn/utils/extmath.py", line 175, in safe_sparse_dot
ret = a * b
File "/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/scipy/sparse/base.py", line 334, in __mul__
raise ValueError('dimension mismatch')
ValueError: dimension mismatch
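The traceback ends in safe_sparse_dot(X, self.feature_log_prob_.T), a product of the (n_docs, n_features) input with the transpose of feature_log_prob_, which has shape (n_classes, n_features). A minimal sketch with scipy.sparse and small made-up shapes reproduces the same kind of ValueError when the column counts disagree. joint_log_likelihood here is a hypothetical stand-in, not scikit-learn's internal method.

```python
# Minimal sketch: shapes are arbitrary, shrunk from the real
# (4, 4) input vs. 28750 training features.
import numpy as np
import scipy.sparse as sp

n_classes, n_train_features = 3, 6

# Stand-in for clf.feature_log_prob_: one row per class,
# one column per feature the classifier saw during fit.
feature_log_prob = np.log(
    np.full((n_classes, n_train_features), 1.0 / n_train_features)
)

def joint_log_likelihood(X):
    # (n_docs, n_features) times (n_features, n_classes):
    # the inner dimensions must agree.
    return X.dot(feature_log_prob.T)

X_ok = sp.csr_matrix(np.ones((4, n_train_features)))  # 4 docs, 6 features
X_bad = sp.csr_matrix(np.ones((4, 4)))                # 4 docs, only 4 features

print(joint_log_likelihood(X_ok).shape)  # (4, 3)

mismatch_error = None
try:
    joint_log_likelihood(X_bad)
except ValueError as exc:
    mismatch_error = exc
print("ValueError:", mismatch_error)
```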
Evidently it has something to do with the dimensions of the term-document matrix.
When I check the shapes of trans_doc, X_train and X_test, I get:
>>> trans_doc.shape
(4, 4)
>>> X_train.shape
(145314, 28750)
>>> X_test.shape
(48439, 28750)
For y_pred = clf.predict(trans_doc) to work, I need (as far as I understand) to turn new_doc into a term-document matrix of shape (4, 28750). But I don't know of any method in CountVectorizer that lets me do this.
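A minimal sketch, on a made-up five-document corpus, of how the width of CountVectorizer's output is fixed by the vocabulary learned during fit. transform() on the already fitted vectorizer yields one column per training term, whereas fit_transform() on new_doc with a fresh vectorizer builds a four-word vocabulary and is one way to end up with a (4, 4) matrix. The last step uses CountVectorizer's vocabulary parameter to pin a second vectorizer to the training vocabulary.

```python
# Minimal sketch: train_docs is a made-up corpus; only the mechanics matter.
from sklearn.feature_extraction.text import CountVectorizer

train_docs = [
    "starbucks coffee latte",
    "walmart groceries discount",
    "macdonalds burger fries",
    "target discount store",
    "coffee and a muffin",
]
new_doc = ["MacDonalds", "Walmart", "Target", "Starbucks"]

vect = CountVectorizer()
X = vect.fit_transform(train_docs)  # learns a 13-term vocabulary
print(X.shape)  # (5, 13)

# transform() reuses the fitted vocabulary: one column per training term.
trans_doc = vect.transform(new_doc)
print(trans_doc.shape)  # (4, 13)

# fit_transform() on new_doc learns a new vocabulary from just these
# four words, so the matrix has only four columns.
refit = CountVectorizer().fit_transform(new_doc)
print(refit.shape)  # (4, 4)

# A separate vectorizer can be pinned to the training vocabulary.
pinned = CountVectorizer(vocabulary=vect.vocabulary_).transform(new_doc)
print(pinned.shape)  # (4, 13)
```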