我有一些示例代码只能在一台机器上运行。经过一些测试,我发现工作的机器运行R 3.4.2而其他一切运行3.4.3。
经过一些工作,我发现你访问ml_decision_tree参数的方式有所改变。我正试着拿到标签。这是旧方法:
model_iris$model.parameters$labels
这不再适用了。如果你在脚本的其余部分的上下文中运行它,我得到一个null。我已经尝试实际查看列表对象以确定层次结构标签的存储位置,我可以看到它们,但无论我做什么,我似乎都无法深入了解它们。
以下是整个脚本的一个版本:
library(tidyverse)
library(sparklyr)
library(Rcpp)
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)
partition_iris <- sdf_partition(import_iris, training=0.5, testing=0.5)
sdf_register(partition_iris, c("spark_iris_training", "spark_iris_test"))
tidy_iris <- tbl(sc, "spark_iris_training") %>%
select(Species, Petal_Length, Petal_Width)
model_iris <- tidy_iris %>%
ml_decision_tree(response="Species", features=c("Petal_Length", "Petal_Width"))
test_iris <- tbl(sc, "spark_iris_test")
pred_iris <- sdf_predict(model_iris, test_iris) %>%
collect
library(ggplot2)
pred_iris %>%
inner_join(data.frame(prediction=0:2, lab=model_iris$model.parameters$labels)) %>%
ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()
编辑:我正在运行的软件包版本似乎有所不同。 工作代码运行闪烁0.6.3。破碎的版本是0.7.0-9004。
答案 0 :(得分:1)
model_iris$model.parameters$labels
访问 model_iris$.index_labels
。
您可以改为运行:
pred_iris %>%
inner_join(data.frame(prediction=0:2, lab=model_iris$.index_labels)) %>%
ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()
但是,由于model_iris$.index_labels
是内部的,为防止代码在将来中断,我们应该从原始数据集或预测数据框中获取标签:
pred_iris %>%
inner_join(data.frame(prediction=0:2, lab=unique(iris$Species))) %>%
ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()
,或者
pred_iris %>%
inner_join(data.frame(prediction=0:2, lab=unique(pred_iris$predicted_label))) %>%
ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()
答案 1 :(得分:0)
pred_iris
应该有predicted_label
列,其中包含您需要的内容。您是否还有其他需要从模型对象获取标签的用例?