我使用Naive Bayes进行文本分类
以下是我用于理解朴素贝叶斯的链接
https://www.analyticsvidhya.com/blog/2015/09/naive-bayes-explained/
虽然我得到了很好的预测结果,但我无法理解失败案例的原因
我使用predictProbabilities来测量特征的概率来理解其原因 正确的预测
以下是我的理解,基于此我试图找出在某些情况下预测错误的原因
假设我的测试数据如下(我有大约100000条培训记录)
Text Classification
There is a murder in town - HIGH SEVERITY
The old women was murdered - HIGH SEVERITY
Boy was hit by ball in street - LOW SEVERITY
John sprained his ankle while playing - LOW SEVERITY
现在当我对下面的句子进行预测时 "城市发生了一起谋杀案。 - 我希望模型可以预测HIGH SEVERITY。 但有时模型预测LOW SEVERITY
我拿出了所有含有相同字词的文字,并试图弄清楚为什么会这样。 如果我使用https://www.analyticsvidhya.com/blog/2015/09/naive-bayes-explained/中的公式手动计算概率,则应该正确预测。 但我找不到任何预测错误的原因。
如果我遗漏了任何重要信息,请告诉我
下面添加的代码片段
我的训练数据框由三列组成" id" ,"风险","标签"
该文本已使用stanford NLP
进行了调整 // TOKENIZE DATA
regexTokenizer = new RegexTokenizer()
.setInputCol("text")
.setOutputCol("words")
.setPattern("\\W");
DataFrame tokenized = regexTokenizer.transform(trainingRiskData);
// REMOVE STOP WORDS
remover = new StopWordsRemover().setInputCol("words").setOutputCol("filtered");
DataFrame stopWordsRemoved = remover.transform(tokenized);
// COMPUTE TERM FREQUENCY USING HASHING
int numFeatures = 20;
hashingTF = new HashingTF().setInputCol("filtered").setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
DataFrame rawFeaturizedData = hashingTF.transform(stopWordsRemoved);
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
idfModel = idf.fit(rawFeaturizedData);
DataFrame featurizedData = idfModel.transform(rawFeaturizedData);
JavaRDD<LabeledPoint> labelledJavaRDD = featurizedData.select("label", "features").toJavaRDD()
.map(new Function<Row, LabeledPoint>() {
@Override
public LabeledPoint call(Row arg0) throws Exception {
LabeledPoint labeledPoint = new LabeledPoint(new Double(arg0.get(0).toString()),
(org.apache.spark.mllib.linalg.Vector) arg0.get(1));
return labeledPoint;
}
});
NaiveBayes naiveBayes = new NaiveBayes(1.0, "multinomial");
NaiveBayesModel naiveBayesModel = naiveBayes.train(labelledJavaRDD.rdd(), 1.0);
构建训练模型后,测试数据将通过相同的转换传递,并使用下面的代码进行预测 第3列是测试数据框中的标签。 第7列是测试数据框中的特征
LabeledPoint labeledPoint = new LabeledPoint(new Double(dataFrameRow.get(3).toString()),
(org.apache.spark.mllib.linalg.Vector) dataFrameRow.get(7));
double predictedLabel = naiveBayesModel.predict(labeledPoint.features());