如何更改算法中的参数以获得更好的性能?

时间:2015-07-02 16:34:29

标签: python machine-learning scikit-learn classification

我在一组推文上运行了Multinomial和Bernoulli Naive Bayes,以及Linear SVC。他们在60/40分开的1000条培训推文中表现良好(分别为80%,80%,90%)。

每个算法都有可以更改的参数,我想知道是否可以通过更改参数获得更好的结果。除了培训,测试和预测之外,我对机器学习知之甚少,所以我想知道是否有人可以就我可以调整哪些参数给出一些建议。

以下是我使用的代码:

import codecs
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB,BernoulliNB
from sklearn import svm

trainfile = 'training_words.txt'
testfile = 'testing_words.txt'

word_vectorizer = CountVectorizer(analyzer='word')  
trainset = word_vectorizer.fit_transform(codecs.open(trainfile,'r','utf8'))
tags = training_labels

mnb = svm.LinearSVC() #Or any other classifier
mnb.fit(trainset, tags)
codecs.open(testfile,'r','utf8')
testset = word_vectorizer.transform(codecs.open(testfile,'r','utf8'))
results = mnb.predict(testset)

print results

1 个答案:

答案 0 :(得分:2)

您可以使用Grid Search Cross Validation通过分层K-Fold交叉验证拆分来调整模型参数。这是一个示例代码。

import codecs
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB,BernoulliNB
from sklearn import svm
from sklearn.grid_search import GridSearchCV

testfile = 'testing_words.txt'

word_vectorizer = CountVectorizer(analyzer='word')  
trainset = word_vectorizer.fit_transform(codecs.open(trainfile,'r','utf8'))
tags = training_labels


mnb = svm.LinearSVC() # or any other classifier
# check out the sklearn online docs to see what params choice we have for your
# particular choice of estimator, for SVM, C, class_weight are important ones to tune 
params_space = {'C': np.logspace(-5, 0, 10), 'class_weight':[None, 'auto']}
# build a grid search cv, n_jobs=-1 to use all your processor cores
gscv = GridSearchCV(mnb, params_space, cv=10, n_jobs=-1)
# fit the model
gscv.fit(trainset, tags)
# give a look at your best params combination and best score you have
gscv.best_estimator_
gscv.best_params_
gscv.best_score_


codecs.open(testfile,'r','utf8')
testset = word_vectorizer.transform(codecs.open(testfile,'r','utf8'))
results = gscv.predict(testset)

print results