从Sparklyr中提取和可视化模型树

时间:2018-11-02 18:14:14

标签: r apache-spark random-forest decision-tree sparklyr

是否有人对如何将Sparklyr的ml_decision_tree_classifier,ml_gbt_classifier或ml_random_forest_classifier模型中的树信息转换为a。)其他R树相关库可以理解的格式以及(最终)b。)进行可视化处理?非技术消耗的树木有哪些?这将包括从向量汇编器期间生成的替换字符串索引值转换回实际特征名称的功能。

下面的代码是从a sparklyr blog post中自由复制的,以提供示例:

library(sparklyr)
library(dplyr)

# If needed, install Spark locally via `spark_install()`
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)

# split the data into train and validation sets
iris_data <- iris_tbl %>%
  sdf_partition(train = 2/3, validation = 1/3, seed = 123)


iris_pipeline <- ml_pipeline(sc) %>%
  ft_dplyr_transformer(
    iris_data$train %>%
      mutate(Sepal_Length = log(Sepal_Length),
             Sepal_Width = Sepal_Width ^ 2)
  ) %>%
  ft_string_indexer("Species", "label")

iris_pipeline_model <- iris_pipeline %>%
  ml_fit(iris_data$train)

iris_vector_assembler <- ft_vector_assembler(
  sc, 
  input_cols = setdiff(colnames(iris_data$train), "Species"), 
  output_col = "features"
)
random_forest <- ml_random_forest_classifier(sc,features_col = "features")

# obtain the labels from the fitted StringIndexerModel
iris_labels <- iris_pipeline_model %>%
  ml_stage("string_indexer") %>%
  ml_labels()

# IndexToString will convert the predicted numeric values back to class labels
iris_index_to_string <- ft_index_to_string(sc, "prediction", "predicted_label", 
                                      labels = iris_labels)

# construct a pipeline with these stages
iris_prediction_pipeline <- ml_pipeline(
  iris_pipeline, # pipeline from previous section
  iris_vector_assembler, 
  random_forest,
  iris_index_to_string
)

# fit to data and make some predictions
iris_prediction_model <- iris_prediction_pipeline %>%
  ml_fit(iris_data$train)
iris_predictions <- iris_prediction_model %>%
  ml_transform(iris_data$validation)
iris_predictions %>%
  select(Species, label:predicted_label) %>%
  glimpse()

在根据here的建议进行反复试验之后,我能够以“ if / else”格式将基本决策树的格式打印为字符串:

model_stage <- iris_prediction_model$stages[[3]]

spark_jobj(model_stage) %>% invoke(., "toDebugString") %>% cat()
##print out below##
RandomForestClassificationModel (uid=random_forest_classifier_5c6a1934c8e) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 4.95)
      If (feature 3 <= 1.65)
       Predict: 0.0
      Else (feature 3 > 1.65)
       If (feature 0 <= 1.7833559100698644)
        Predict: 0.0
       Else (feature 0 > 1.7833559100698644)
        Predict: 2.0
     Else (feature 2 > 4.95)
      If (feature 2 <= 5.05)
       If (feature 1 <= 6.505000000000001)
        Predict: 2.0
       Else (feature 1 > 6.505000000000001)
        Predict: 0.0
      Else (feature 2 > 5.05)
       Predict: 2.0
  Tree 1 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.75)
      If (feature 1 <= 5.0649999999999995)
       If (feature 3 <= 1.05)
        Predict: 0.0
       Else (feature 3 > 1.05)
        If (feature 0 <= 1.8000241202036602)
         Predict: 2.0
        Else (feature 0 > 1.8000241202036602)
         Predict: 0.0
      Else (feature 1 > 5.0649999999999995)
       If (feature 0 <= 1.8000241202036602)
        Predict: 0.0
       Else (feature 0 > 1.8000241202036602)
        If (feature 2 <= 5.05)
         Predict: 0.0
        Else (feature 2 > 5.05)
         Predict: 2.0
     Else (feature 3 > 1.75)
      Predict: 2.0
  Tree 2 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 0 <= 1.7664051342320237)
      Predict: 0.0
     Else (feature 0 > 1.7664051342320237)
      If (feature 3 <= 1.45)
       If (feature 2 <= 4.85)
        Predict: 0.0
       Else (feature 2 > 4.85)
        Predict: 2.0
      Else (feature 3 > 1.45)
       If (feature 3 <= 1.65)
        If (feature 1 <= 8.125)
         Predict: 2.0
        Else (feature 1 > 8.125)
         Predict: 0.0
       Else (feature 3 > 1.65)
        Predict: 2.0
  Tree 3 (weight 1.0):
    If (feature 0 <= 1.6675287895788053)
     If (feature 2 <= 2.5)
      Predict: 1.0
     Else (feature 2 > 2.5)
      Predict: 0.0
    Else (feature 0 > 1.6675287895788053)
     If (feature 3 <= 1.75)
      If (feature 3 <= 1.55)
       If (feature 1 <= 7.025)
        If (feature 2 <= 4.55)
         Predict: 0.0
        Else (feature 2 > 4.55)
         Predict: 2.0
       Else (feature 1 > 7.025)
        Predict: 0.0
      Else (feature 3 > 1.55)
       If (feature 2 <= 5.05)
        Predict: 0.0
       Else (feature 2 > 5.05)
        Predict: 2.0
     Else (feature 3 > 1.75)
      Predict: 2.0
  Tree 4 (weight 1.0):
    If (feature 2 <= 4.85)
     If (feature 2 <= 2.5)
      Predict: 1.0
     Else (feature 2 > 2.5)
      Predict: 0.0
    Else (feature 2 > 4.85)
     If (feature 2 <= 5.05)
      If (feature 0 <= 1.8484238118815566)
       Predict: 2.0
      Else (feature 0 > 1.8484238118815566)
       Predict: 0.0
     Else (feature 2 > 5.05)
      Predict: 2.0
  Tree 5 (weight 1.0):
    If (feature 2 <= 1.65)
     Predict: 1.0
    Else (feature 2 > 1.65)
     If (feature 3 <= 1.65)
      If (feature 0 <= 1.8325494627242664)
       Predict: 0.0
      Else (feature 0 > 1.8325494627242664)
       If (feature 2 <= 4.95)
        Predict: 0.0
       Else (feature 2 > 4.95)
        Predict: 2.0
     Else (feature 3 > 1.65)
      Predict: 2.0
  Tree 6 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 5.05)
      If (feature 3 <= 1.75)
       Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 2.0
     Else (feature 2 > 5.05)
      Predict: 2.0
  Tree 7 (weight 1.0):
    If (feature 3 <= 0.55)
     Predict: 1.0
    Else (feature 3 > 0.55)
     If (feature 3 <= 1.65)
      If (feature 2 <= 4.75)
       Predict: 0.0
      Else (feature 2 > 4.75)
       Predict: 2.0
     Else (feature 3 > 1.65)
      If (feature 2 <= 4.85)
       If (feature 0 <= 1.7833559100698644)
        Predict: 0.0
       Else (feature 0 > 1.7833559100698644)
        Predict: 2.0
      Else (feature 2 > 4.85)
       Predict: 2.0
  Tree 8 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.85)
      If (feature 2 <= 4.85)
       Predict: 0.0
      Else (feature 2 > 4.85)
       If (feature 0 <= 1.8794359129669855)
        Predict: 2.0
       Else (feature 0 > 1.8794359129669855)
        If (feature 3 <= 1.55)
         Predict: 0.0
        Else (feature 3 > 1.55)
         Predict: 0.0
     Else (feature 3 > 1.85)
      Predict: 2.0
  Tree 9 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 4.95)
      Predict: 0.0
     Else (feature 2 > 4.95)
      Predict: 2.0
  Tree 10 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 2 <= 4.95)
      Predict: 0.0
     Else (feature 2 > 4.95)
      If (feature 2 <= 5.05)
       If (feature 3 <= 1.55)
        Predict: 2.0
       Else (feature 3 > 1.55)
        If (feature 3 <= 1.75)
         Predict: 0.0
        Else (feature 3 > 1.75)
         Predict: 2.0
      Else (feature 2 > 5.05)
       Predict: 2.0
  Tree 11 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 2 <= 5.05)
      If (feature 2 <= 4.75)
       Predict: 0.0
      Else (feature 2 > 4.75)
       If (feature 3 <= 1.75)
        Predict: 0.0
       Else (feature 3 > 1.75)
        Predict: 2.0
     Else (feature 2 > 5.05)
      Predict: 2.0
  Tree 12 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.75)
      If (feature 3 <= 1.35)
       Predict: 0.0
      Else (feature 3 > 1.35)
       If (feature 0 <= 1.695573522904327)
        Predict: 0.0
       Else (feature 0 > 1.695573522904327)
        If (feature 1 <= 8.125)
         Predict: 2.0
        Else (feature 1 > 8.125)
         Predict: 0.0
     Else (feature 3 > 1.75)
      If (feature 0 <= 1.7833559100698644)
       Predict: 0.0
      Else (feature 0 > 1.7833559100698644)
       Predict: 2.0
  Tree 13 (weight 1.0):
    If (feature 3 <= 0.55)
     Predict: 1.0
    Else (feature 3 > 0.55)
     If (feature 2 <= 4.95)
      If (feature 2 <= 4.75)
       Predict: 0.0
      Else (feature 2 > 4.75)
       If (feature 0 <= 1.8000241202036602)
        If (feature 1 <= 9.305)
         Predict: 2.0
        Else (feature 1 > 9.305)
         Predict: 0.0
       Else (feature 0 > 1.8000241202036602)
        Predict: 0.0
     Else (feature 2 > 4.95)
      Predict: 2.0
  Tree 14 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 3 <= 1.65)
      If (feature 3 <= 1.45)
       Predict: 0.0
      Else (feature 3 > 1.45)
       If (feature 2 <= 4.95)
        Predict: 0.0
       Else (feature 2 > 4.95)
        Predict: 2.0
     Else (feature 3 > 1.65)
      If (feature 0 <= 1.7833559100698644)
       If (feature 0 <= 1.7664051342320237)
        Predict: 2.0
       Else (feature 0 > 1.7664051342320237)
        Predict: 0.0
      Else (feature 0 > 1.7833559100698644)
       Predict: 2.0
  Tree 15 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 3 <= 1.75)
      If (feature 2 <= 4.95)
       Predict: 0.0
      Else (feature 2 > 4.95)
       If (feature 1 <= 8.125)
        Predict: 2.0
       Else (feature 1 > 8.125)
        If (feature 0 <= 1.9095150692894909)
         Predict: 0.0
        Else (feature 0 > 1.9095150692894909)
         Predict: 2.0
     Else (feature 3 > 1.75)
      Predict: 2.0
  Tree 16 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 0 <= 1.7491620461964392)
      Predict: 0.0
     Else (feature 0 > 1.7491620461964392)
      If (feature 3 <= 1.75)
       If (feature 2 <= 4.75)
        Predict: 0.0
       Else (feature 2 > 4.75)
        If (feature 0 <= 1.8164190316151556)
         Predict: 2.0
        Else (feature 0 > 1.8164190316151556)
         Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 2.0
  Tree 17 (weight 1.0):
    If (feature 0 <= 1.695573522904327)
     If (feature 2 <= 1.65)
      Predict: 1.0
     Else (feature 2 > 1.65)
      Predict: 0.0
    Else (feature 0 > 1.695573522904327)
     If (feature 2 <= 4.75)
      If (feature 2 <= 2.5)
       Predict: 1.0
      Else (feature 2 > 2.5)
       Predict: 0.0
     Else (feature 2 > 4.75)
      If (feature 3 <= 1.75)
       If (feature 1 <= 5.0649999999999995)
        Predict: 2.0
       Else (feature 1 > 5.0649999999999995)
        If (feature 3 <= 1.65)
         Predict: 0.0
        Else (feature 3 > 1.65)
         Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 2.0
  Tree 18 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.65)
      Predict: 0.0
     Else (feature 3 > 1.65)
      If (feature 0 <= 1.7833559100698644)
       Predict: 0.0
      Else (feature 0 > 1.7833559100698644)
       Predict: 2.0
  Tree 19 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 4.95)
      If (feature 1 <= 8.705)
       Predict: 0.0
      Else (feature 1 > 8.705)
       If (feature 2 <= 4.85)
        Predict: 0.0
       Else (feature 2 > 4.85)
        If (feature 0 <= 1.8164190316151556)
         Predict: 2.0
        Else (feature 0 > 1.8164190316151556)
         Predict: 0.0
     Else (feature 2 > 4.95)
      Predict: 2.0

如您所见,这种格式对于传入我见过的许多精美的决策树图形可视化方法(例如revolution analyticsstatmethods)来说并不是最佳选择

1 个答案:

答案 0 :(得分:4)

截止到今天(Spark 2.4.0版本已经批准并等待官方宣布),您最好的选择*可能是{{3},而无需使用复杂的第三方工具(例如MLeap)。 },然后读回save the model

ml_stage(iris_prediction_model, "random_forest") %>% 
  ml_save("/tmp/model")

rf_spec <- spark_read_parquet(sc, "rf", "/tmp/model/data/")

结果将是具有以下架构的Spark DataFrame

rf_spec %>% 
  spark_dataframe() %>% 
  invoke("schema") %>% invoke("treeString") %>% 
  cat(sep = "\n")
root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)

提供有关所有节点和拆分的信息。

可以使用列元数据检索功能映射:

meta <- iris_predictions %>% 
    select(features) %>% 
    spark_dataframe() %>% 
    invoke("schema") %>% invoke("apply", 0L) %>% 
    invoke("metadata") %>% 
    invoke("getMetadata", "ml_attr") %>% 
    invoke("getMetadata", "attrs") %>% 
    invoke("json") %>%
    jsonlite::fromJSON() %>% 
    dplyr::bind_rows() %>% 
    copy_to(sc, .) %>%
    rename(featureIndex = idx)

meta
# Source: spark<?> [?? x 2]
  featureIndex name        
*        <int> <chr>       
1            0 Sepal_Length
2            1 Sepal_Width 
3            2 Petal_Length
4            3 Petal_Width 

以及已经获取的标签映射:

labels <- tibble(prediction = seq_along(iris_labels) - 1, label = iris_labels) %>%
  copy_to(sc, .)

最后,您可以将所有这些结合起来:

full_rf_spec <- rf_spec %>% 
  spark_dataframe() %>% 
  invoke("selectExpr", list("treeID", "nodeData.*", "nodeData.split.*")) %>% 
  sdf_register() %>% 
  select(-split, -impurityStats) %>% 
  left_join(meta, by = "featureIndex") %>% 
  left_join(labels, by = "prediction")

full_rf_spec
# Source: spark<?> [?? x 12]
   treeID    id prediction impurity    gain leftChild rightChild featureIndex
 *  <int> <int>      <dbl>    <dbl>   <dbl>     <int>      <int>        <int>
 1      0     0          1   0.636   0.379          1          2            2
 2      0     1          1   0      -1             -1         -1           -1
 3      0     2          0   0.440   0.367          3          8            2
 4      0     3          0   0.0555  0.0269         4          5            3
 5      0     4          0   0      -1             -1         -1           -1
 6      0     5          0   0.5     0.5            6          7            0
 7      0     6          0   0      -1             -1         -1           -1
 8      0     7          2   0      -1             -1         -1           -1
 9      0     8          2   0.111   0.0225         9         12            2
10      0     9          2   0.375   0.375         10         11            1
# ... with more rows, and 4 more variables: leftCategoriesOrThreshold <list>,
#   numCategories <int>, name <chr>, label <chr>

treeID收集并分隔的

应该提供足够的信息**,以模仿树状对象(您可以通过查看the specification文档和/或{ {1}}建立unclass模型。rpart::rpart.object所需的工作量较少,但其绘图实用程序远非令人印象深刻),并且可以构建出不错的绘图。

另一种方法是使用tree::tree并使用此表示将数据导出到PMML。

您还可以选中Sparklyr2PMML,它建议使用第三方Python软件包来解决相同的问题。

如果您不需要任何花哨的内容,可以使用rpart创建一个粗略的情节:

igraph

How do I visualise / plot a decision tree in Apache Spark (PySpark 1.4.1)?


*它将在不久的将来得到改进,使用新引入的格式不可知ML编写器API(该模型已经支持某些模型的PMML编写器。希望会出现新的模型和格式)。

**如果您使用分类功能,则可能需要将library(igraph) gframe <- full_rf_spec %>% filter(treeID == 0) %>% # Take the first tree mutate( leftCategoriesOrThreshold = ifelse( size(leftCategoriesOrThreshold) == 1, # Continuous variable case concat("<= ", round(concat_ws("", leftCategoriesOrThreshold), 3)), # Categorical variable case. Decoding variables might be involved # but can be achieved if needed, using column metadata or indexer labels concat("in {", concat_ws(",", leftCategoriesOrThreshold), "}") ), name = coalesce(name, label)) %>% select( id, label, impurity, gain, leftChild, rightChild, leftCategoriesOrThreshold, name) %>% collect() vertices <- gframe %>% rename(label = name, name = id) edges <- gframe %>% transmute(from = id, to = leftChild, label = leftCategoriesOrThreshold) %>% union_all(gframe %>% select(from = id, to = rightChild)) %>% filter(to != -1) g <- igraph::graph_from_data_frame(edges, vertices = vertices) plot( g, layout = layout_as_tree(g, root = c(1)), vertex.shape = "rectangle", vertex.size = 45) 映射到相应的索引级别。

如果特征向量包含分类变量,则leftCategoriesOrThreshold的输出将包含jsonlite::fromJSON()组。例如,如果您将具有三个级别的列nominal编入索引,并且在第一位置进行了汇编,则将是这样的:

foo

其中$nominal vals idx name 1 a, b, c 1 foo 列是可变长度向量的列表。

vals
length(meta$nominal$vals[[1]])

标签对应于此结构的索引,因此在示例中:

  • [1] 3 的标签为0.0(不是标签是双精度浮点数,并且编号从0.0开始)
  • a的标签为1.0

,依此类推,如果您用b等于leftCategoriesOrThreshold进行拆分,则意味着拆分位于标签c(0.0, 2.0)上。

请注意,如果存在分类数据,则可能需要在调用{"a", "c"}之前对其进行处理-到目前为止,它似乎还不支持复杂的字段。

在Spark <= 2.3中,您将必须使用R代码进行映射(在本地结构上,某些copy_to应该可以)。在Spark 2.4(purrr AFAIK中尚不支持)中,使用Spark的JSON阅读器直接读取元数据并使用其高阶函数进行映射可能会更容易。