spark naive bayes预测分析

时间:2016-10-09 09:10:09

标签: apache-spark machine-learning apache-spark-mllib naivebayes

我使用Naive Bayes进行文本分类

以下是我用于理解朴素贝叶斯的链接

https://www.analyticsvidhya.com/blog/2015/09/naive-bayes-explained/

虽然我得到了很好的预测结果,但我无法理解失败案例的原因

我使用predictProbabilities来测量特征的概率来理解其原因 正确的预测

以下是我的理解,基于此我试图找出在某些情况下预测错误的原因

假设我的测试数据如下(我有大约100000条培训记录)

Text                                      Classification
There is a murder in town                 - HIGH SEVERITY
The old women was murdered                - HIGH SEVERITY
Boy was hit by ball in street             - LOW SEVERITY
John sprained his ankle while playing     - LOW SEVERITY

现在当我对下面的句子进行预测时 "城市发生了一起谋杀案。 - 我希望模型可以预测HIGH SEVERITY。 但有时模型预测LOW SEVERITY

我拿出了所有含有相同字词的文字,并试图弄清楚为什么会这样。 如果我使用https://www.analyticsvidhya.com/blog/2015/09/naive-bayes-explained/中的公式手动计算概率,则应该正确预测。 但我找不到任何预测错误的原因。

如果我遗漏了任何重要信息,请告诉我

下面添加的代码片段

我的训练数据框由三列组成" id" ,"风险","标签"

该文本已使用stanford NLP

进行了调整
    // TOKENIZE DATA

            regexTokenizer = new RegexTokenizer()
                      .setInputCol("text")
                      .setOutputCol("words")
                      .setPattern("\\W"); 

            DataFrame tokenized = regexTokenizer.transform(trainingRiskData);

    // REMOVE STOP WORDS

            remover = new StopWordsRemover().setInputCol("words").setOutputCol("filtered");

            DataFrame stopWordsRemoved = remover.transform(tokenized);

// COMPUTE TERM FREQUENCY USING HASHING

        int numFeatures = 20;
        hashingTF = new HashingTF().setInputCol("filtered").setOutputCol("rawFeatures")
                .setNumFeatures(numFeatures);
        DataFrame rawFeaturizedData = hashingTF.transform(stopWordsRemoved);

IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
        idfModel = idf.fit(rawFeaturizedData);

        DataFrame featurizedData = idfModel.transform(rawFeaturizedData);

JavaRDD<LabeledPoint> labelledJavaRDD = featurizedData.select("label", "features").toJavaRDD()
                .map(new Function<Row, LabeledPoint>() {

                    @Override
                    public LabeledPoint call(Row arg0) throws Exception {
                        LabeledPoint labeledPoint = new LabeledPoint(new Double(arg0.get(0).toString()),
                                (org.apache.spark.mllib.linalg.Vector) arg0.get(1));
                        return labeledPoint;
                    }
                });

NaiveBayes naiveBayes = new NaiveBayes(1.0, "multinomial");
        NaiveBayesModel naiveBayesModel = naiveBayes.train(labelledJavaRDD.rdd(), 1.0);

构建训练模型后,测试数据将通过相同的转换传递,并使用下面的代码进行预测 第3列是测试数据框中的标签。 第7列是测试数据框中的特征

LabeledPoint labeledPoint = new LabeledPoint(new Double(dataFrameRow.get(3).toString()),
                        (org.apache.spark.mllib.linalg.Vector) dataFrameRow.get(7));

double predictedLabel = naiveBayesModel.predict(labeledPoint.features());

0 个答案:

没有答案