如何访问Sparklyr包中ml_decision_tree中的模型参数?

时间:2018-02-16 21:11:56

标签: r apache-spark sparklyr

我有一些示例代码只能在一台机器上运行。经过一些测试,我发现工作的机器运行R 3.4.2而其他一切运行3.4.3。

经过一些工作,我发现你访问ml_decision_tree参数的方式有所改变。我正试着拿到标签。这是旧方法:

model_iris$model.parameters$labels

这不再适用了。如果你在脚本的其余部分的上下文中运行它,我得到一个null。我已经尝试实际查看列表对象以确定层次结构标签的存储位置,我可以看到它们,但无论我做什么,我似乎都无法深入了解它们。

以下是整个脚本的一个版本:

library(tidyverse)
library(sparklyr)
library(Rcpp)
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)

partition_iris <- sdf_partition(import_iris, training=0.5, testing=0.5)

sdf_register(partition_iris, c("spark_iris_training", "spark_iris_test"))

tidy_iris <- tbl(sc, "spark_iris_training") %>%
  select(Species, Petal_Length, Petal_Width)

model_iris <- tidy_iris %>%
  ml_decision_tree(response="Species", features=c("Petal_Length", "Petal_Width"))

test_iris <- tbl(sc, "spark_iris_test")

pred_iris <- sdf_predict(model_iris, test_iris) %>%
  collect

library(ggplot2)

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=model_iris$model.parameters$labels)) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

编辑:我正在运行的软件包版本似乎有所不同。 工作代码运行闪烁0.6.3。破碎的版本是0.7.0-9004。

2 个答案:

答案 0 :(得分:1)

现在可以使用model_iris$model.parameters$labels访问

model_iris$.index_labels

您可以改为运行:

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=model_iris$.index_labels)) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

但是,由于model_iris$.index_labels是内部的,为防止代码在将来中断,我们应该从原始数据集或预测数据框中获取标签:

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=unique(iris$Species))) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

,或者

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=unique(pred_iris$predicted_label))) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

答案 1 :(得分:0)

pred_iris应该有predicted_label列,其中包含您需要的内容。您是否还有其他需要从模型对象获取标签的用例?