scikit-learn - Using a single string with RandomForestClassifier.predict()?

Posted: 2018-07-21 21:39:48

Tags: scikit-learn text-classification

I'm an sklearn dummy... I'm trying to predict the label for a given string with a RandomForestClassifier() that has been fitted on texts and labels.

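Obviously I don't know how to use predict() on a single string. I'm calling reshape() because a while ago I got this error: "Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample."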

How do I predict the label of a single text string?

Script:

#!/usr/bin/env python
''' Read a txt file consisting of '<label>: <long string of text>'
    to use as a model for predicting the label for a string
'''

from argparse import ArgumentParser
import json
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder


def main(args):
    '''
    args: Arguments obtained by _Get_Args()
    '''

    print('Loading data...')
    # Load data from args.txtfile and split the lines into
    # two lists (labels, texts).
    data = open(args.txtfile).readlines()
    labels, texts = ([], [])
    for line in data:
        label, text = line.split(': ', 1)
        labels.append(label)
        texts.append(text)

    # Print a list of unique labels
    print(json.dumps(list(set(labels)), indent=4))

    # Instantiate a CountVectorizer, fit it on the texts,
    # and encode the labels as integers.
    cv = CountVectorizer(
            stop_words='english',
            strip_accents='unicode',
            lowercase=True,
            )
    matrix = cv.fit_transform(texts)
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels)
    rf = RandomForestClassifier()
    rf.fit(matrix, labels)

    # Try to predict the label for args.string.
    prediction = Predict_Label(args.string, cv, rf)
    print(prediction)


def Predict_Label(string, cv, rf):
    '''
    string: str() - A string of text
    cv: The CountVectorizer class
    rf: The RandomForestClassifier class
    '''

    matrix = cv.fit_transform([string])
    matrix = matrix.reshape(1, -1)
    try:
        prediction = rf.predict(matrix)
    except Exception as E:
        print(str(E))
    else:
        return prediction


def _Get_Args():
    parser = ArgumentParser(description='Learn labels from text')
    parser.add_argument('-t', '--txtfile', required=True)
    parser.add_argument('-s', '--string', required=True)
    return parser.parse_args()


if __name__ == '__main__':
    args = _Get_Args()
    main(args)

The real training data file is 43663 lines long, but a sample of it is in small_list.txt, where every line has the format: <label>: <long text string>

The error is printed by the exception handler:

$ ./learn.py -t small_list.txt -s 'This is a string that might have something to do with phishing or fraud'
Loading data...
[
    "Vulnerabilities__Unknown",
    "Vulnerabilities__MSSQL Browsing Service",
    "Fraud__Phishing",
    "Fraud__Copyright/Trademark Infringement",
    "Attacks and Reconnaissance__Web Attacks",
    "Vulnerabilities__Vulnerable SMB",
    "Internal Report__SBL Notify",
    "Objectionable Content__Russian Federation Objectionable Material",
    "Malicious Code/Traffic__Malicious URL",
    "Spam__Marketing Spam",
    "Attacks and Reconnaissance__Scanning",
    "Malicious Code/Traffic__Unknown",
    "Attacks and Reconnaissance__SSH Brute Force",
    "Spam__URL in Spam",
    "Vulnerabilities__Vulnerable Open Memcached",
    "Malicious Code/Traffic__Sinkhole",
    "Attacks and Reconnaissance__SMTP Brute Force",
    "Illegal content__Child Pornography"
]
Number of features of the model must match the input. Model n_features is 2070 and input n_features is 3 
None
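The "input n_features is 3" part can be reproduced in a few lines. The following is a minimal sketch, assuming a hypothetical three-line toy corpus in place of the real file. Calling `fit_transform` again builds a brand-new vocabulary from the single string (after English stop-word removal only "string", "phishing" and "fraud" are left, hence 3 columns). Calling `transform` on the already-fitted vectorizer keeps the training columns. `sklearn.base.clone` is used only so the demo does not overwrite `cv`'s training vocabulary.

```python
from sklearn.base import clone
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical toy training corpus standing in for the real texts.
texts = [
    "phishing email asking for bank credentials",
    "ssh brute force attempts against server",
    "marketing spam newsletter with discount url",
]
cv = CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True)
train_matrix = cv.fit_transform(texts)  # one column per training-vocabulary term

new_text = "This is a string that might have something to do with phishing or fraud"

# Re-fitting (what Predict_Label does with cv.fit_transform) learns a new
# vocabulary from the single string. An unfitted clone is used here so that
# cv keeps its training vocabulary.
refit = clone(cv).fit_transform([new_text])

# transform() maps the string onto the training vocabulary instead.
one = cv.transform([new_text])

print(train_matrix.shape[1], refit.shape[1], one.shape)
```

Because the input is a list containing one string, `transform` already returns a 2-D sparse matrix of shape (1, n_features), so `reshape(1, -1)` changes nothing. The reshape hint in the earlier error applies to 1-D numeric arrays, not to vectorizer output.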

1 answer:

Answer 0 (score: 0)

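You need to take the vocabulary learned by the first CountVectorizer (cv) and use it to transform the new single text before predicting. Otherwise the new text is vectorized against a different vocabulary, and the number of features does not match what the forest was trained on.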

...

cv = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
        )

matrix = cv.fit_transform(texts)
encoder = LabelEncoder()
labels = encoder.fit_transform(labels)
rf = RandomForestClassifier()
rf.fit(matrix, labels)

# Try to predict the label for args.string.
cv_new = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
        vocabulary=cv.vocabulary_
        )
prediction = Predict_Label(args.string, cv_new, rf)
print(prediction)

...
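A somewhat shorter route is a sketch along the same lines, not the answerer's code. The fitted `cv` can be reused directly with `cv.transform([string])`, so no second vectorizer is needed. Because the labels were passed through `LabelEncoder`, `rf.predict` returns integer codes, and `encoder.inverse_transform` turns them back into label names. The toy `lines` data and the helper name `predict_label` are hypothetical. The `make_pipeline` variant at the end is an alternative that keeps vectorizer and classifier fitted together.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy data in the question's "<label>: <text>" format.
lines = [
    "Fraud__Phishing: fake bank login page asking for credentials",
    "Fraud__Phishing: phishing email with spoofed invoice link",
    "Spam__Marketing Spam: cheap pills discount offer newsletter",
    "Spam__Marketing Spam: unsolicited marketing newsletter discount",
]
labels, texts = map(list, zip(*(line.split(": ", 1) for line in lines)))

cv = CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True)
matrix = cv.fit_transform(texts)
encoder = LabelEncoder()
y = encoder.fit_transform(labels)  # label strings -> integer codes
rf = RandomForestClassifier(n_estimators=50, random_state=0)
rf.fit(matrix, y)


def predict_label(text, cv, rf, encoder):
    """Hypothetical helper: predict the label name for one string."""
    # transform() (not fit_transform) reuses the training vocabulary,
    # giving shape (1, n_features) with no reshape needed.
    row = cv.transform([text])
    encoded = rf.predict(row)  # integer code(s), e.g. array([0])
    return encoder.inverse_transform(encoded)[0]  # back to the label string


label = predict_label("suspicious phishing login link", cv, rf, encoder)
print(label)

# Alternative: a pipeline fits and applies the vectorizer and forest together,
# so predict() can never see a differently fitted vocabulary.
pipe = make_pipeline(
    CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True),
    RandomForestClassifier(n_estimators=50, random_state=0),
)
pipe.fit(texts, labels)  # classifiers also accept string labels directly
pipe_label = pipe.predict(["suspicious phishing login link"])[0]
print(pipe_label)
```

With the pipeline, the LabelEncoder step becomes optional, because RandomForestClassifier returns the original string labels from predict().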