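I'm an sklearn dummy... I'm trying to predict the label of a given string with a RandomForestClassifier() that was fitted on texts and labels.
Clearly I don't know how to use predict() on a single string. I'm using reshape() because a while ago I got this error: "Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample."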
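For context, that reshape hint targets NumPy arrays of numbers: a single numeric sample is a 1-D array that has to become one row of shape (1, n_features). Text behaves differently. The sketch below, using a made-up two-document corpus, suggests that for CountVectorizer the "single sample" is simply a one-element list, and that a bare string is rejected outright, so reshape() is not the right tool for raw text.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

# Numeric case: one sample with 3 features must be reshaped into a single row.
numeric_sample = np.array([0.5, 1.2, 3.4])
print(numeric_sample.reshape(1, -1).shape)

# Text case: the vectorizer expects an iterable of documents.
cv = CountVectorizer()
cv.fit(["phishing email fraud", "ssh brute force scan"])  # toy corpus, made up

# One sample = a list holding one string; the result is a 1-row sparse matrix.
single = cv.transform(["possible phishing or fraud"])
print(single.shape)

# Passing the string itself (not wrapped in a list) raises ValueError.
bare_string_error = None
try:
    cv.transform("possible phishing or fraud")
except ValueError as exc:
    bare_string_error = exc
    print("ValueError:", exc)
```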
How do I predict the label of a single text string?
```python
#!/usr/bin/env python
''' Read a txt file consisting of '<label>: <long string of text>'
to use as a model for predicting the label for a string
'''
from argparse import ArgumentParser
import json

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder


def main(args):
    '''
    args: Arguments obtained by _Get_Args()
    '''
    print('Loading data...')
    # Load data from args.txtfile and split the lines into
    # two lists (labels, texts).
    data = open(args.txtfile).readlines()
    labels, texts = ([], [])
    for line in data:
        label, text = line.split(': ', 1)
        labels.append(label)
        texts.append(text)
    # Print a list of unique labels
    print(json.dumps(list(set(labels)), indent=4))
    # Instantiate a CountVectorizer class and fit the texts
    # and labels into it.
    cv = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
    )
    matrix = cv.fit_transform(texts)
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels)
    rf = RandomForestClassifier()
    rf.fit(matrix, labels)
    # Try to predict the label for args.string.
    prediction = Predict_Label(args.string, cv, rf)
    print(prediction)


def Predict_Label(string, cv, rf):
    '''
    string: str() - A string of text
    cv: The CountVectorizer class
    rf: The RandomForestClassifier class
    '''
    matrix = cv.fit_transform([string])
    matrix = matrix.reshape(1, -1)
    try:
        prediction = rf.predict(matrix)
    except Exception as E:
        print(str(E))
    else:
        return prediction


def _Get_Args():
    parser = ArgumentParser(description='Learn labels from text')
    parser.add_argument('-t', '--txtfile', required=True)
    parser.add_argument('-s', '--string', required=True)
    return parser.parse_args()


if __name__ == '__main__':
    args = _Get_Args()
    main(args)
```
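The real training file is 43,663 lines long, but a sample of it is in small_list.txt, whose lines all follow the format <label>: <long text string>. Running the script on the sample gives: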
```
$ ./learn.py -t small_list.txt -s 'This is a string that might have something to do with phishing or fraud'
Loading data...
[
    "Vulnerabilities__Unknown",
    "Vulnerabilities__MSSQL Browsing Service",
    "Fraud__Phishing",
    "Fraud__Copyright/Trademark Infringement",
    "Attacks and Reconnaissance__Web Attacks",
    "Vulnerabilities__Vulnerable SMB",
    "Internal Report__SBL Notify",
    "Objectionable Content__Russian Federation Objectionable Material",
    "Malicious Code/Traffic__Malicious URL",
    "Spam__Marketing Spam",
    "Attacks and Reconnaissance__Scanning",
    "Malicious Code/Traffic__Unknown",
    "Attacks and Reconnaissance__SSH Brute Force",
    "Spam__URL in Spam",
    "Vulnerabilities__Vulnerable Open Memcached",
    "Malicious Code/Traffic__Sinkhole",
    "Attacks and Reconnaissance__SMTP Brute Force",
    "Illegal content__Child Pornography"
]
Number of features of the model must match the input. Model n_features is 2070 and input n_features is 3
None
```
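The mismatch comes from Predict_Label calling cv.fit_transform([string]): that re-fits the vectorizer on the query alone, replacing the training vocabulary with the query's non-stop-words (here string, phishing and fraud, hence 3 features), while the forest was trained on 2070 columns. The reshape(1, -1) changes nothing, because the matrix is already a single row. A minimal reproduction, assuming three made-up training texts in place of small_list.txt:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical training texts standing in for small_list.txt.
train_texts = [
    "phishing page asking for bank credentials",
    "ssh brute force attempts from scanner",
    "marketing spam with tracking url",
]
train_labels = [0, 1, 2]  # already integer-encoded, as LabelEncoder would produce
query = "This is a string that might have something to do with phishing or fraud"

cv = CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True)
train_matrix = cv.fit_transform(train_texts)
rf = RandomForestClassifier(n_estimators=10, random_state=0)
rf.fit(train_matrix, train_labels)

# transform() reuses the vocabulary learned from the training texts,
# so the query gets exactly as many columns as the model expects.
good = cv.transform([query])
print("transform():", good.shape, rf.predict(good))

# What Predict_Label does: fit_transform() discards the learned vocabulary
# and builds a new one from the query alone (stop words removed).
bad = cv.fit_transform([query])
print("fit_transform():", bad.shape, sorted(cv.vocabulary_))

mismatch_error = None
try:
    rf.predict(bad)
except ValueError as exc:  # exact wording differs between sklearn versions
    mismatch_error = exc
    print("ValueError:", exc)
```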
Answer
You need to take the vocabulary learned by the first CountVectorizer (cv) and use it to transform the new single text before predicting:
```python
    ...
    cv = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
    )
    matrix = cv.fit_transform(texts)
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels)
    rf = RandomForestClassifier()
    rf.fit(matrix, labels)
    # Try to predict the label for args.string.
    # cv_new is pinned to the vocabulary learned by cv, so the new text is
    # encoded into the same columns the forest was trained on.
    cv_new = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
        vocabulary=cv.vocabulary_
    )
    prediction = Predict_Label(args.string, cv_new, rf)
    print(prediction)
    ...
```
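An even smaller change to the question's code is to replace cv.fit_transform([string]) with cv.transform([string]) in Predict_Label and drop the reshape, since the already-fitted cv applies the training vocabulary on its own. A further alternative, swapped in here rather than taken from the answer, is a scikit-learn Pipeline that bundles the vectorizer and the forest, so predict() always goes through transform(). The sketch below assumes a few made-up lines in the '<label>: <text>' format and a hypothetical predict_label helper; it also decodes the integer prediction back to a label string with LabelEncoder.inverse_transform, which the question's code does not do.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

# Hypothetical sample lines in the '<label>: <text>' format.
lines = [
    "Fraud__Phishing: fake bank login page harvesting credentials",
    "Fraud__Phishing: phishing email with credential harvesting link",
    "Attacks and Reconnaissance__SSH Brute Force: repeated ssh login failures",
    "Spam__Marketing Spam: unsolicited marketing newsletter offer",
]
labels, texts = [], []
for line in lines:
    label, text = line.split(": ", 1)  # split only on the first ': '
    labels.append(label)
    texts.append(text)

encoder = LabelEncoder()
y = encoder.fit_transform(labels)

# The pipeline fits the vectorizer on the training texts only; at predict
# time it calls transform() on new text, so the feature count always matches.
model = make_pipeline(
    CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True),
    RandomForestClassifier(random_state=0),
)
model.fit(texts, y)


def predict_label(model, encoder, string):
    """Hypothetical helper: return the decoded label for one raw string."""
    encoded = model.predict([string])  # wrap the string: one document, one sample
    return encoder.inverse_transform(encoded)[0]


print(predict_label(model, encoder,
                    "This is a string that might have something to do with phishing or fraud"))
```

Keeping the vectorizer inside the pipeline also means the fitted model can be saved and reloaded as a single object, with no risk of pairing it with a differently fitted vectorizer.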