使用sklearn

时间:2015-08-26 12:26:33

标签: python scikit-learn svm document-classification

我正在使用sklearn和支持向量机来分类文档。我希望将文档放入的类别是{课程,非课程},其中课程表示由大学专业提供的课程和非课程提供的课程的网页文本。

我为此而构建的这个类类似于这个类:

import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.grid_search import GridSearchCV
from sklearn import metrics

class TestSVG(object):
    def __init__(self):

        self.text_clf = Pipeline([('vect', CountVectorizer()),
                                  ('tfidf', TfidfTransformer()),
                                  ('clf', SGDClassifier(loss='hinge', alpha=1e-3, random_state=42)),
                                ])

        self.grid_params = {'vect__ngram_range' : [(1, 1), (1, 2), (1, 3)],
                            'tfidf__use_idf': (True, False),
                            'clf__alpha': (1e-2, 1e-3),
                           }

        self.gs_clf = GridSearchCV(self.text_clf, self.grid_params, n_jobs=-1)
        self.training_target = []
        self.training_data = []
        self.testing_data = []
        self.testing_target = []

        self.classifier = None

    def train(self, training_data, training_target):
        self.training_data = training_data
        self.training_target = training_target
        self.classifier = self.gs_clf.fit(self.training_data, self.training_target)

    def predict(self, text):    
        if isinstance(text, basestring):
            text = [text]       
        elif not isinstance(text, list):
            raise ValueError("Input for prediction must be text of a list")

        if self.classifier is None:
            raise ValueError("Classifier must be trained to make predictions.")

        return self.classifier.predict(text)

    def test(self, testing_data, testing_target):
        self.testing_data = testing_data
        self.testing_target = testing_target

        predicted = self.classifier.predict(self.testing_data)
        return np.mean(predicted == testing_target)

为了收集课程的培训数据,我写了一些webscraping课程,它们为一组网页提取文本,这些网页的基本网址是我硬编码的。

我在这一点上陷入困​​境。我最初的策略是将课程描述页面作为非课程文档。但是,因为我将文档分类为课程,然后基本上是"其他任何内容",我不确定是否应该使用相关内容或完全不相关的内容,例如非课程文档的一组预定义的维基百科页面。

我计划使用课程描述的原因是我的最终计划是使用scrapy来创建Web链接图。然后,我可以使用支持向量机遍历图表,抓取文本,找到未知的课程列表页面。我担心的是,如果svm没有经过这样的培训,我会得到误报。

非常感谢任何见解。

1 个答案:

答案 0 :(得分:0)

最好让训练数据尽可能类似于测试数据(或您预测的数据)。恕我直言,非课程数据应包括两个略有相关的文件(例如:期刊文件和完全无关的文件(体育新闻)。