Can't determine the input format needed to predict on a dataset trained with doc2vec and a random forest classifier

Time: 2018-07-14 17:40:42

Tags: python machine-learning random-forest doc2vec

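I'm trying to make predictions on a dataset using some predefined data: tweets plus the category each tweet belongs to, labeled 1-16. I built a model from that data with doc2vec and trained a random forest classifier on it. I'm confused about what format the data needs to be in before I call clf.predict(tweet).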
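A minimal sketch of the input contract, assuming scikit-learn and NumPy: `clf.predict` takes a 2-D numeric array of shape (n_samples, n_features), where n_features must equal the number of columns the forest was fitted on (100 here, from `vector_size=100`). The random vectors below are made-up stand-ins for doc2vec vectors, not real tweet data.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

VECTOR_SIZE = 100  # same as Doc2Vec(vector_size=100) in the question

# Random numbers standing in for doc2vec vectors (illustration only, not real tweets)
rng = np.random.default_rng(0)
X_train = rng.normal(size=(32, VECTOR_SIZE))  # shape (n_samples, n_features)
y_train = np.arange(32) % 16 + 1              # labels 1..16, two samples per label

clf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X_train, y_train)

one_vector = rng.normal(size=VECTOR_SIZE)  # like infer_vector's output: shape (100,)

# A bare 1-D vector is rejected: predict expects rows of samples
try:
    clf.predict(one_vector)
    rejected_1d = False
except ValueError:
    rejected_1d = True

single = clf.predict(one_vector.reshape(1, -1))           # shape (1, 100) -> 1 label
batch = clf.predict(np.vstack([one_vector, one_vector]))  # shape (2, 100) -> 2 labels
print(rejected_1d, single, batch)
```

Stacking all inferred vectors into one array and calling `predict` once gives the same labels as a per-tweet `reshape(1, -1)` loop, just in a single call.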

import csv
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import itertools
import nltk  # nltk.word_tokenize also needs NLTK's punkt tokenizer data downloaded
from gensim import utils
from gensim.models import Doc2Vec
import gensim
import numpy as np

# wrap each tweet and its label in the object gensim's Doc2Vec iterates over
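class LabeledLineSentence(object):

    def __init__(self, doc_list, labels_list):
        self.labels_list = labels_list
        self.doc_list = doc_list

    def __iter__(self):
        # zip replaces Python 2's itertools.izip
        for t, l in zip(self.doc_list, self.labels_list):
            #change here
            t = nltk.word_tokenize(t)  # Doc2Vec expects a list of tokens, not a raw string
            #end of change
            # TaggedDocument replaces the removed LabeledSentence; the label is used as the tag
            yield gensim.models.doc2vec.TaggedDocument(t, [l])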

#predefined
tweets = ["a tweet", "another tweet", ... , "a thousandth tweet"]
labels_list = [1, 1, ... , 16] #what category the tweet belongs to

training_data = LabeledLineSentence(tweets, labels_list)

#build the doc2vec model
model = Doc2Vec(vector_size=100, min_count=1, dm=1)
model.build_vocab(training_data)
model.train(training_data, total_examples=model.corpus_count, epochs=20)

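# build the training features: one doc2vec vector per training tweet
# (the tags are the category labels, so every tweet in a category gets that category's vector)
train_tweets = []

for label in labels_list:
    train_tweets.append(model.dv[label])  # gensim 4.x; gensim 3.x used model.docvecs[label]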

# convert to numpy arrays: features of shape (n_tweets, vector_size), labels of shape (n_tweets,)
train_tweets = np.array(train_tweets)
train_labels = np.array(labels_list)

#fit classifier
clf = RandomForestClassifier().fit(train_tweets, train_labels)


# this is the data I am trying to classify into labels
test_data = ["an unseen tweet", "another unseen tweet", ... , "a thousandth unseen tweet"]

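#*******change here***************
for t in test_data:
    split = nltk.word_tokenize(t)     # tokenize exactly as during training
    vect = model.infer_vector(split)  # 1-D array of length vector_size
    vect = vect.reshape(1, -1)        # predict expects a 2-D array: (1 sample, n_features)
    print(clf.predict(vect))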

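The end of the code, the loop over test_data, is where I get confused. I'm fairly sure I built the doc2vec model and trained the classifier correctly, but I'm not sure what I should do to each tweet in the test data before calling clf.predict. I tried tokenizing the strings and using a count vectorizer, but I keep getting errors saying they can't be converted to float. Is there some other way to process the test data before using it for prediction?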
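A sketch of why the attempts described above fail, assuming scikit-learn and NumPy: the forest was fitted on 100 doc2vec dimensions, so raw strings cannot be converted to floats, and `CountVectorizer` output, although numeric, has one column per vocabulary word rather than the 100 doc2vec dimensions. The random training matrix is a stand-in, not real data.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer

rng = np.random.default_rng(0)
# Stand-in for the 100-dimensional doc2vec training features (random, illustration only)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(
    rng.normal(size=(20, 100)), np.arange(20) % 4 + 1
)

test_data = ["an unseen tweet", "another unseen tweet"]


def rejected(x):
    """Return True if clf.predict raises ValueError for input x."""
    try:
        clf.predict(x)
    except ValueError:
        return True
    return False


# 1) Raw strings are not numeric features: conversion to float fails
raw_rejected = rejected(np.array(test_data).reshape(-1, 1))

# 2) CountVectorizer output is numeric, but its columns are vocabulary counts,
#    not the 100 doc2vec dimensions the forest was trained on
counts = CountVectorizer().fit_transform(test_data).toarray()
counts_rejected = rejected(counts)

print(raw_rejected, counts_rejected, counts.shape)
```

Both inputs raise `ValueError`. The classifier only accepts the same kind of features it was trained on: one doc2vec vector per tweet (for example from `model.infer_vector`), stacked into shape (n_tweets, 100).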

0 answers:

No answers yet