我想从已保存(或未保存)的DecisionTreeClassificationModel
中获取树节点的权重。但是我找不到与之相似的东西。
在不知道任何分类的情况下,模型如何实际执行分类。下面是保存在模型中的参数:
{"class":"org.apache.spark.ml.classification.DecisionTreeClassificationModel"
"timestamp":1551207582648
"sparkVersion":"2.3.2"
"uid":"DecisionTreeClassifier_4ffc94d20f1ddb29f282"
"paramMap":{
"cacheNodeIds":false
"maxBins":32
"minInstancesPerNode":1
"predictionCol":"prediction"
"minInfoGain":0.0
"rawPredictionCol":"rawPrediction"
"featuresCol":"features"
"probabilityCol":"probability"
"checkpointInterval":10
"seed":956191873026065186
"impurity":"gini"
"maxMemoryInMB":256
"maxDepth":2
"labelCol":"indexed"
}
"numFeatures":1
"numClasses":2
}
答案 0 :(得分:1)
通过使用treeWeights
:
treeWeights
返回每棵树的权重
1.5.0版中的新功能。
所以
不知道这些模型中的任何一个,模型如何实际执行分类。
权重被存储,而不是作为元数据的一部分存储。如果您有model
from pyspark.ml.classification import RandomForestClassificationModel
model: RandomForestClassificationModel = ...
并将其保存到磁盘
path: str = ...
model.save(path)
您将看到编写者创建了treesMetadata
子目录。如果加载内容(默认编写器使用Parquet):
import os
trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))
您将看到以下结构:
trees_metadata.printSchema()
root
|-- treeID: integer (nullable = true)
|-- metadata: string (nullable = true)
|-- weights: double (nullable = true)
其中weights
列包含由treeID
标识的树的权重。
类似地,节点数据存储在data
子目录中(例如,参见Extract and Visualize Model Trees from Sparklyr):
spark.read.parquet(os.path.join(path, "data")).printSchema()
root
|-- id: integer (nullable = true)
|-- prediction: double (nullable = true)
|-- impurity: double (nullable = true)
|-- impurityStats: array (nullable = true)
| |-- element: double (containsNull = true)
|-- gain: double (nullable = true)
|-- leftChild: integer (nullable = true)
|-- rightChild: integer (nullable = true)
|-- split: struct (nullable = true)
| |-- featureIndex: integer (nullable = true)
| |-- leftCategoriesOrThreshold: array (nullable = true)
| | |-- element: double (containsNull = true)
| |-- numCategories: integer (nullable = true)
等效信息(减去树数据和树权重)也可用于DecisionTreeClassificationModel
。