我正在与sparklyr
包一起使用Spark中的数据。我正在建立一个逻辑回归模型,并且在弄清楚如何从模型上的ml_predict()
生成的预测数据帧的概率列中获取每个类别的概率时遇到了一些麻烦。
以下是一些简短的示例代码,演示了我在做什么:
library(sparklyr)
library(dplyr)
sc <- spark_connect(master = "local[1]", version = "2.3.2")
iris_sc <- copy_to(sc, iris)
modelPipeline <- ml_pipeline(sc) %>%
ft_r_formula(Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width) %>%
ml_logistic_regression()
modelFit <- ml_fit(modelPipeline, iris_sc)
predictions <- ml_predict(modelFit, iris_sc)
由reprex package(v0.2.1)于2018-10-30创建
这将产生一个Spark数据帧,其中包含名为probability
的列,该列为
Spark中的org.apache.spark.ml.linalg.VectorUDT
数据类型,每行三个元素,代表三个可能类别中每个类别的模型概率预测。
如何使用sparklyr
从此对象中获取这些值之一?当然,probability[1]
之类的东西在sparklyr
中不起作用,并且我在dplyr
或sparklyr
中找不到可能有用的函数。