在cross_val_predict之后对新文档进行分类

时间:2017-01-23 22:06:44

标签: python twitter machine-learning scikit-learn classification

我有大约10,000条推文的样本,我想将其分类为“相关”和“不相关”类别。我正在使用Python的scikit-learn来学习这个模型。我手动将1,000条推文编码为“相关”或“不相关”。然后,我使用80%的手动编码数据作为训练数据运行SVM模型,其余作为测试数据。我获得了良好的结果(预测准确度~0.90),但为了避免过度拟合,我决定对所有1,000个手动编码的推文使用交叉验证。

在我的示例中已经获得推文的tf-idf矩阵后,下面是我的代码。 “target”是一个数组,列出推文是标记为“相关”还是“不相关”。

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict

clf = SGDClassifier()
scores = cross_val_score(clf, X_tfidf, target, cv=10)
predicted = cross_val_predict(clf, X_tfidf, target, cv=10)

通过这段代码,我能够预测出1000条推文属于哪些类,我可以将其与手动编码进行比较。

为了使用我的模型对其他约9,000条我没有手动编码的推文进行分类,我仍然坚持下一步该做什么。我想再次使用cross_val_predict,但我不确定在第三个参数中放什么,因为这个类是我想要预测的。

提前感谢您的所有帮助!

1 个答案:

答案 0 :(得分:4)

cross_val_predict是方法,可以从模型中实际获取预测。交叉验证是一种模型选择/评估技术,不是训练模型。 cross_val_predict是一个非常具体的函数(它为您提供了许多模型的预测,在交叉验证过程中进行了训练)。对于实际的模型构建,yu应该使用 fit 来训练你的模型,并且预测来获得预测。此处不涉及交叉验证 - 如前所述 - 这是用于模型选择(选择分类器,超级参数等)而不是训练实际模型。