如何在Spark Naive Bayes分类器中获取类的概率?

时间:2015-08-05 20:46:45

标签: apache-spark apache-spark-mllib naivebayes

我在Spark中训练NaiveBayesModel,但是当我使用它来预测新实例时,我需要获得每个类的概率。我查看了NaiveBayesModel中预测函数的代码,并提出以下代码:

val thetaMatrix = new DenseMatrix (model.labels.length,model.theta(0).length,model.theta.flatten,true)
val piVector = new DenseVector(model.pi)
//val prob = thetaMatrix.multiply(test.features)

val x = test.map {p =>       
  val prob = thetaMatrix.multiply(p.features)          
  BLAS.axpy(1.0, piVector, prob)
  prob
}

这是否正常?第BLAS.axpy(1.0, piVector, prob)行不断给我一个错误,即“' axpy'找不到。

1 个答案:

答案 0 :(得分:2)

在最近的pull-request中,这被添加到Spark主干中,并将在Spark 1.5中发布(关闭SPARK-4362)。你可以这样打电话

def predictProbabilities(testData: RDD[Vector]): RDD[Vector]

def predictProbabilities(testData: Vector): Vector