Spark ML:DecisionTreeClassificatonModel如何知道树的权重?

时间:2019-02-26 19:28:55

标签: apache-spark pyspark apache-spark-ml

我想从已保存(或未保存)的DecisionTreeClassificationModel中获取树节点的权重。但是我找不到与之相似的东西。

在不知道任何分类的情况下,模型如何实际执行分类。下面是保存在模型中的参数:

{"class":"org.apache.spark.ml.classification.DecisionTreeClassificationModel"
"timestamp":1551207582648
"sparkVersion":"2.3.2"
"uid":"DecisionTreeClassifier_4ffc94d20f1ddb29f282"
"paramMap":{
"cacheNodeIds":false
"maxBins":32
"minInstancesPerNode":1
"predictionCol":"prediction"
"minInfoGain":0.0
"rawPredictionCol":"rawPrediction"
"featuresCol":"features"
"probabilityCol":"probability"
"checkpointInterval":10
"seed":956191873026065186
"impurity":"gini"
"maxMemoryInMB":256
"maxDepth":2
"labelCol":"indexed"
}
"numFeatures":1
"numClasses":2
}

1 个答案:

答案 0 :(得分:1)

通过使用treeWeights

  

treeWeights

     

返回每棵树的权重

     

1.5.0版中的新功能。

所以

  

不知道这些模型中的任何一个,模型如何实际执行分类。

权重被存储,而不是作为元数据的一部分存储。如果您有model

from pyspark.ml.classification import RandomForestClassificationModel

model: RandomForestClassificationModel = ...

并将其保存到磁盘

path: str = ...

model.save(path)

您将看到编写者创建了treesMetadata子目录。如果加载内容(默认编写器使用Parquet):

import os

trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))

您将看到以下结构:

trees_metadata.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)

其中weights列包含由treeID标识的树的权重。

类似地,节点数据存储在data子目录中(例如,参见Extract and Visualize Model Trees from Sparklyr):

spark.read.parquet(os.path.join(path, "data")).printSchema()     
root
 |-- id: integer (nullable = true)
 |-- prediction: double (nullable = true)
 |-- impurity: double (nullable = true)
 |-- impurityStats: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- gain: double (nullable = true)
 |-- leftChild: integer (nullable = true)
 |-- rightChild: integer (nullable = true)
 |-- split: struct (nullable = true)
 |    |-- featureIndex: integer (nullable = true)
 |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- numCategories: integer (nullable = true)

等效信息(减去树数据和树权重)也可用于DecisionTreeClassificationModel