I have a small corpus and I want to calculate the accuracy of a naive Bayes classifier using 10-fold cross-validation. How can I do that?
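Answer 0 (score: 26)
Your options are to either set this up yourself or use something like NLTK-Trainer, since NLTK doesn't directly support cross-validation for machine learning algorithms.
I'd recommend just using another module to do this for you, but if you really want to write your own code, you could do something like the following.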
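Supposing you want 10-fold, you would have to partition your training set into 10 subsets, train on 9/10, test on the remaining 1/10, and do this for each combination of subsets (10).
Assuming your training set is in a list named training, a simple way to accomplish this would be: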
    num_folds = 10
    subset_size = len(training) // num_folds  # integer division; plain / yields a float on Python 3
    for i in range(num_folds):
        testing_this_round = training[i*subset_size:][:subset_size]
        training_this_round = training[:i*subset_size] + training[(i+1)*subset_size:]
        # train using training_this_round
        # evaluate against testing_this_round
        # save accuracy
    # find mean accuracy over all rounds
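The loop above can be fleshed out into something runnable. What follows is only a minimal sketch: `cross_validate`, `train_fn`, and `accuracy_fn` are hypothetical names, and a majority-label classifier stands in for naive Bayes so that the example needs nothing beyond the standard library. The toy data is made up for illustration.

```python
from collections import Counter

def cross_validate(data, num_folds, train_fn, accuracy_fn):
    """Return the per-fold accuracies of a k-fold cross-validation.

    data is a list of (features, label) pairs. Any items left over after
    num_folds * subset_size stay in the training part of every fold.
    """
    subset_size = len(data) // num_folds
    if subset_size == 0:
        raise ValueError("need at least num_folds items")
    accuracies = []
    for i in range(num_folds):
        start, stop = i * subset_size, (i + 1) * subset_size
        test_fold = data[start:stop]
        train_fold = data[:start] + data[stop:]
        model = train_fn(train_fold)
        accuracies.append(accuracy_fn(model, test_fold))
    return accuracies

# Hypothetical stand-in classifier: always predicts the most common training label.
def train_majority(labeled):
    return Counter(label for _, label in labeled).most_common(1)[0][0]

def majority_accuracy(model, labeled):
    return sum(label == model for _, label in labeled) / len(labeled)

# Toy data made up for illustration: 20 items, 15 "pos" and 5 "neg".
data = [({"w": n}, "pos" if n % 4 else "neg") for n in range(20)]
scores = cross_validate(data, 10, train_majority, majority_accuracy)
mean_accuracy = sum(scores) / len(scores)
```

With NLTK, one would presumably pass `nltk.NaiveBayesClassifier.train` as `train_fn` and `nltk.classify.util.accuracy` as `accuracy_fn`, since both take arguments in the same order as the stand-ins here.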
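Answer 1 (score: 21)
Actually, there is no need for the long loop iteration given in the most upvoted answer. The choice of classifier is also irrelevant (it can be any classifier).
Scikit-learn provides cross_val_score, which does all the looping under the hood.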
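    # sklearn.cross_validation was removed in scikit-learn 0.20; use model_selection
    from sklearn.model_selection import KFold, cross_val_score
    from sklearn.naive_bayes import MultinomialNB

    k_fold = KFold(n_splits=10, shuffle=True, random_state=0)
    clf = MultinomialNB()  # placeholder: any scikit-learn classifier works here
    print(cross_val_score(clf, X, y, cv=k_fold, n_jobs=1))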
Answer 2 (score: 14)
I have used both libraries, NLTK for naive Bayes and sklearn for cross-validation, as follows:
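    import nltk
    from sklearn.model_selection import KFold

    # extract_features and documents are assumed to be defined elsewhere
    training_set = nltk.classify.apply_features(extract_features, documents)
    cv = KFold(n_splits=10, shuffle=False)

    accuracies = []
    for train_idx, test_idx in cv.split(training_set):
        # Select items by index; slicing from the first to the last train index
        # would pull the test fold into the training data for the middle folds
        train_fold = [training_set[i] for i in train_idx]
        test_fold = [training_set[i] for i in test_idx]
        classifier = nltk.NaiveBayesClassifier.train(train_fold)
        accuracies.append(nltk.classify.util.accuracy(classifier, test_fold))
        print('accuracy:', accuracies[-1])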
At the end I calculated the average accuracy.
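The averaging step itself is not shown in the answer. A minimal sketch using the standard library's `statistics` module could look like this; the values in `fold_accuracies` are invented placeholders, not results from any real run:

```python
import statistics

# Placeholder per-fold accuracies (made up for illustration only).
fold_accuracies = [0.80, 0.75, 0.85, 0.70, 0.90, 0.80, 0.75, 0.85, 0.80, 0.80]

mean_accuracy = statistics.mean(fold_accuracies)
# The sample standard deviation gives a rough idea of fold-to-fold spread.
spread = statistics.stdev(fold_accuracies)

print("mean accuracy: %.3f (+/- %.3f)" % (mean_accuracy, spread))
```

Reporting the spread alongside the mean is a common habit with small corpora, where individual folds can vary a lot.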
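Answer 3 (score: 1)
A modification of the second answer that shuffles the data before splitting. The line below uses the current sklearn.model_selection.KFold signature; the older sklearn.cross_validation.KFold took len(training_set) as its first argument and n_folds instead of n_splits:
    cv = KFold(n_splits=10, shuffle=True, random_state=None)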
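To see what shuffling changes, here is a rough standard-library sketch of how shuffled fold indices could be produced. `shuffled_folds` is a hypothetical helper, not scikit-learn's implementation; it only illustrates that after a seeded shuffle every index still lands in exactly one test fold. Shuffling matters when the corpus is ordered by label, because unshuffled folds could then contain a single class each.

```python
import random

def shuffled_folds(n_items, num_folds, seed=0):
    """Yield (train_indices, test_indices) pairs over a shuffled index list."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded so the folds are reproducible
    for k in range(num_folds):
        # Striding spreads any remainder across folds instead of dropping it.
        test_idx = indices[k::num_folds]
        held_out = set(test_idx)
        train_idx = [i for i in indices if i not in held_out]
        yield train_idx, test_idx

folds = list(shuffled_folds(23, 10))
```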
Answer 4 (score: 1)
Inspired by Jared's answer, here is a version using a generator:
    def k_fold_generator(X, y, k_fold):
        # X and y must be lists: + concatenates them (NumPy arrays would add element-wise)
        subset_size = len(X) // k_fold  # integer division, works on Python 2 and 3
        for k in range(k_fold):
            X_train = X[:k * subset_size] + X[(k + 1) * subset_size:]
            X_valid = X[k * subset_size:][:subset_size]
            y_train = y[:k * subset_size] + y[(k + 1) * subset_size:]
            y_valid = y[k * subset_size:][:subset_size]
            yield X_train, y_train, X_valid, y_valid
I am assuming that your data set X has N data points (= 4 in the example) and D features (= 2 in the example). The associated N labels are stored in y.
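A small usage sketch for the generator above, assuming X and y are plain Python lists. The four data points and their labels are made up to match N = 4 and D = 2, and the generator is restated in compact form so the snippet runs on its own:

```python
def k_fold_generator(X, y, k_fold):
    subset_size = len(X) // k_fold  # integer division so slicing works on Python 3
    for k in range(k_fold):
        start, stop = k * subset_size, (k + 1) * subset_size
        # Yield order: X_train, y_train, X_valid, y_valid
        yield X[:start] + X[stop:], y[:start] + y[stop:], X[start:stop], y[start:stop]

# Made-up data: N = 4 points with D = 2 features each, plus one label per point.
X = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = [0, 0, 1, 1]

folds = list(k_fold_generator(X, y, k_fold=2))
for X_train, y_train, X_valid, y_valid in folds:
    print(X_valid, y_valid)
```

Note that with data sorted by label like this, each validation fold holds a single class, which is one reason to shuffle before splitting.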