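I have a dataset that contains only two columns useful for training my model: the first is the news headline and the second is the news category.
So I successfully ran the following training code in Python: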
import re
import numpy as np
import pandas as pd
# the Naive Bayes model
from sklearn.naive_bayes import MultinomialNB
# function to split the data for cross-validation
from sklearn.model_selection import train_test_split
# function for transforming documents into counts
from sklearn.feature_extraction.text import CountVectorizer
# function for encoding categories
from sklearn.preprocessing import LabelEncoder
# grab the data
news = pd.read_csv("/Users/helloworld/Downloads/NewsAggregatorDataset/newsCorpora.csv",encoding='latin-1')
news.head()
def normalize_text(s):
    s = s.lower()
    # remove punctuation that is not word-internal (e.g., hyphens, apostrophes)
    s = re.sub(r'\s\W', ' ', s)
    s = re.sub(r'\W\s', ' ', s)
    # make sure we didn't introduce any double spaces
    s = re.sub(r'\s+', ' ', s)
    return s
news['TEXT'] = [normalize_text(s) for s in news['TITLE']]
# pull the data into vectors
vectorizer = CountVectorizer()
x = vectorizer.fit_transform(news['TEXT'])
encoder = LabelEncoder()
y = encoder.fit_transform(news['CATEGORY'])
# split into train and test sets
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
nb = MultinomialNB()
nb.fit(x_train, y_train)
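To see what normalize_text actually does to a headline, here is a small standalone check. It repeats the function above (with raw-string regexes) and runs it on two made-up headlines, which are purely illustrative and not from the dataset.

```python
import re

def normalize_text(s):
    s = s.lower()
    # remove punctuation that is not word-internal (e.g., hyphens, apostrophes)
    s = re.sub(r'\s\W', ' ', s)   # whitespace followed by a non-word char
    s = re.sub(r'\W\s', ' ', s)   # non-word char followed by whitespace
    # make sure we didn't introduce any double spaces
    s = re.sub(r'\s+', ' ', s)
    return s

# Hypothetical headlines, made up for illustration
print(normalize_text("Fed's Rate-Hike: Stocks Fall!"))
print(normalize_text('Apple "unveils" iPhone'))
```

The colon and the quotes are stripped because they touch a space, while the apostrophe and hyphen survive. The trailing "!" also survives because nothing follows it. That is harmless here, since CountVectorizer's default token pattern only keeps runs of two or more word characters and ignores punctuation anyway.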
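So my question is: how can I supply new data (for example, just a news headline) and have the program predict its news category using scikit-learn?
P.S. My training data looks like this: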
Answer 0 (score: 3)
You should train the model on the training data (as you did), and then make predictions on new data (the test data).
Do the following:
nb = MultinomialNB()
nb.fit(x_train, y_train)
y_predicted = nb.predict(x_test)
Now, if you want to evaluate the predictions by accuracy, you can do the following:
from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_predicted)
You can compute other metrics in the same way.
Finally, you can see all the available metrics here!
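As an illustration of "other metrics", the sketch below computes macro-averaged F1 and a per-class report alongside accuracy. The label lists are hypothetical stand-ins for y_test and y_predicted, not real results.

```python
from sklearn.metrics import accuracy_score, classification_report, f1_score

# Hypothetical encoded labels standing in for y_test and y_predicted
y_test = [0, 1, 2, 2, 1, 0]
y_predicted = [0, 2, 2, 2, 1, 0]

acc = accuracy_score(y_test, y_predicted)                  # fraction of exact matches
macro_f1 = f1_score(y_test, y_predicted, average="macro")  # unweighted mean of per-class F1
print(acc, macro_f1)

# Per-class precision, recall, F1 and support in one table
print(classification_report(y_test, y_predicted))
```

With the real data, passing target_names=encoder.classes_ to classification_report should print category names instead of integer codes, assuming every category appears in the test split.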
Edit 1
When you type:
y_predicted = nb.predict(x_test)
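y_predicted will contain the numeric codes corresponding to your categories (the integers produced by LabelEncoder).
To convert these values back to the category labels, you can do the following: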
y_predicted_labels = encoder.inverse_transform(y_predicted)
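Putting this together for a brand-new headline: the new text must go through the same normalize_text step and then vectorizer.transform (not fit_transform), so its features line up with the vocabulary learned during training. Below is a minimal self-contained sketch; the four training titles and the "business"/"tech" labels are hypothetical toy data standing in for news['TITLE'] and news['CATEGORY'].

```python
import re

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder

def normalize_text(s):
    s = s.lower()
    s = re.sub(r'\s\W', ' ', s)
    s = re.sub(r'\W\s', ' ', s)
    s = re.sub(r'\s+', ' ', s)
    return s

# Hypothetical toy training data
titles = [
    "Stocks rally as markets rebound",
    "Central bank raises interest rates",
    "New smartphone released with faster chip",
    "Software update fixes security bug",
]
categories = ["business", "business", "tech", "tech"]

vectorizer = CountVectorizer()
x = vectorizer.fit_transform([normalize_text(t) for t in titles])
encoder = LabelEncoder()
y = encoder.fit_transform(categories)

nb = MultinomialNB()
nb.fit(x, y)

# New, unseen headlines: same normalization, then transform() with the
# already-fitted vectorizer; words not in the training vocabulary are ignored
new_titles = ["Markets rebound after interest rates cut", "Chip maker releases new software"]
x_new = vectorizer.transform([normalize_text(t) for t in new_titles])
labels = encoder.inverse_transform(nb.predict(x_new))
print(labels)
```

Even for a single headline, pass a list such as [normalize_text(title)]; CountVectorizer.transform rejects a bare string. If you would rather call predict on raw text directly, sklearn.pipeline.make_pipeline(CountVectorizer(), MultinomialNB()) bundles the two steps into one estimator.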
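Answer 1 (score: 1)
You're very close; you only need two more lines of code. This link explains Naive Bayes with scikit-learn: https://www.digitalocean.com/community/tutorials/how-to-build-a-machine-learning-classifier-in-python-with-scikit-learn
The short answer to your question is as follows. First, import the accuracy function: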
from sklearn.metrics import accuracy_score
Use the predict function to test the model:
preds = nb.predict(x_test)
Then check the accuracy:
print(accuracy_score(y_test, preds))
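As a side note, a fitted scikit-learn classifier also has a score(X, y) method that returns the same mean accuracy in one call. The sketch below checks this on a tiny hypothetical split; the texts and labels are made up for illustration.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import MultinomialNB

# Hypothetical toy data standing in for the real train/test split
train_texts = ["stocks rally", "bank raises rates", "new smartphone chip", "software security bug"]
y_train = [0, 0, 1, 1]
test_texts = ["rates rally", "new chip bug", "stocks software"]
y_test = [0, 1, 0]

vectorizer = CountVectorizer()
x_train = vectorizer.fit_transform(train_texts)
x_test = vectorizer.transform(test_texts)  # reuse the training vocabulary

nb = MultinomialNB()
nb.fit(x_train, y_train)

preds = nb.predict(x_test)
acc = accuracy_score(y_test, preds)
score = nb.score(x_test, y_test)  # mean accuracy, computed in one call
print(acc, score)
```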