在随机森林中访问概率对象列

时间:2019-11-02 20:54:29

标签: r sparkr

在SparkR中,我有类似的东西

rf <- spark.randomForest(train, formula, type = "classification")
pred <- predict(rf,test) 

执行

head(pred) 

输出就是您在图像中看到的

enter image description here

Converting SparkR predictions to readable format (number or string)

我如何获得概率值?

1 个答案:

答案 0 :(得分:0)

您必须使用函数values为每个对象调用一个名为sparkR.callJMethod的Java方法。

t(sapply(collect(select(pred, "probability"))$probability, 
         FUN = function(x) sparkR.callJMethod(x, "values")))

这是使用Iris数据集的完整示例。目标值为Species,具有3个级别,总共有150个数据点。

df <- createDataFrame(iris)
model <- spark.randomForest(df, Species ~ ., type = "classification")
summary(model)

predictions <- predict(model, df)

local_prob <- collect(select(predictions, "probability"))$probability

t(sapply(local_prob, FUN = function(x) sparkR.callJMethod(x, "values")))

请注意,这些预测是收集的,如果数据集很大,则可能会耗尽内存。如果是这样,则可以改用head

截断的输出:

       [,1]        [,2]        [,3]
  [1,] 0           0           1   
  [2,] 0           0           1   
  [3,] 0           0           1   
...  
[148,] 0           1           0   
[149,] 0.05        0.95        0   
[150,] 0.01805556  0.9819444   0