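I am currently working with a dataset that requires decision tree classification. Following the MLlib example, the DataFrame I create holds its features as a SparseVector. I know I can extract the root node's impurity and prediction, but I would also like to get the column name for that node.
Suppose my DataFrame looks like this: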
+-----|-------|-------|-------+
| id | col_1 | col_2 | col_3 |
+-----|-------|-------|-------+
| 0 | 1.0 | 0.0 | 2.0 |
| 1 | 2.0 | 1.0 | 0.0 |
| 3 | 2.0 | 2.0 | 1.0 |
+-----|-------|-------|-------+
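My final dataset has many more columns than this; the table only illustrates the example I am working with.
I then use the example code to transform it with a VectorIndexer, which gives me something like this: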
+-----|--------------------------------+
| id | features |
+-----|--------------------------------+
| 0 | (3, [0, 2], [1.0, 2.0]) |
| 1 | (3, [0, 1], [2.0, 1.0]) |
| 3 | (3, [0, 1, 2], [2.0, 2.0, 1.0]) |
+-----|--------------------------------+
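To make the sparse notation concrete, here is a minimal plain-Scala sketch (no Spark needed) that expands a (size, indices, values) triple back into named columns. `SparseRow`, `SparseDecode`, and `decode` are hypothetical names used only for illustration. The sketch assumes the columns were assembled in the order col_1, col_2, col_3, so that position i in the vector corresponds to the i-th column.

```scala
// Hypothetical stand-in for a Spark SparseVector: (size, indices, values).
final case class SparseRow(size: Int, indices: Array[Int], values: Array[Double])

object SparseDecode {
  // Assumed assembly order of the original columns (an assumption, not from Spark).
  val featureCols: Array[String] = Array("col_1", "col_2", "col_3")

  // Expand a sparse row into (columnName -> value) pairs.
  // Positions that are not listed in `indices` are implicitly 0.0.
  def decode(row: SparseRow, names: Array[String]): Seq[(String, Double)] = {
    require(row.size == names.length, "vector size must match the number of columns")
    val present = row.indices.zip(row.values).toMap
    names.zipWithIndex.map { case (name, i) => name -> present.getOrElse(i, 0.0) }.toSeq
  }

  def main(args: Array[String]): Unit = {
    // Row with id 0 from the table above.
    val row0 = SparseRow(3, Array(0, 2), Array(1.0, 2.0))
    decode(row0, featureCols).foreach { case (name, value) => println(s"$name = $value") }
  }
}
```

If that ordering assumption holds, a feature index reported by the tree could be read as an offset into the same list of column names.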
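Finally, I have the learned classification tree and can get the root node from it. What I want, though, is the column name associated with the root node.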
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// `spark` is an existing SparkSession.
val data = spark.read.format("libsvm").load("./src/main/resources/sample_libsvm_data.txt")
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  .fit(data)
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
println("Pipeline stages: ")
pipeline.getStages.foreach(println)
val model = pipeline.fit(trainingData)
// Stage 2 of the fitted pipeline is the decision tree model.
val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)
val root = treeModel.rootNode
println("Root: " + root.toString)
This gives me something like:
Trained classification tree model:
DecisionTreeClassificationModel (uid=dtc_5eadf281a5fc) of depth 4 with 13 nodes
  If (feature 32 in {1.0,2.0,3.0,4.0})
   If (feature 4 in {0.0})
    Predict: 2.0
   Else (feature 4 not in {0.0})
    Predict: 1.0
  Else (feature 32 not in {1.0,2.0,3.0,4.0})
   If (feature 37 in {0.0,2.0})
    If (feature 1 in {1.0})
     Predict: 1.0
    Else (feature 1 not in {1.0})
     If (feature 4 in {0.0,3.0})
      Predict: 0.0
     Else (feature 4 not in {0.0,3.0})
      Predict: 2.0
   Else (feature 37 not in {0.0,2.0})
    If (feature 3 in {0.0})
     Predict: 2.0
    Else (feature 3 not in {0.0})
     Predict: 1.0
Root: InternalNode(prediction = 0.0, impurity = 0.6428571428571429, split = org.apache.spark.ml.tree.CategoricalSplit@a0058edd)
As you can see, the root of the tree is feature 32, but I would like to be able to get its column name.
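One direction that might work, as a rough sketch rather than a confirmed solution: read the split's feature index from the root node and look it up in the ordered list of columns that were assembled into the features vector. `RootFeatureName`, `nameForIndex`, and the `featureCols` array are hypothetical names. The mapping only holds under the assumption that each input column contributes exactly one slot to the vector, in order. The metadata variant assumes the features column carries ML attribute names, which is probably not the case for data loaded from a libsvm file.

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tree.InternalNode
import org.apache.spark.sql.DataFrame

object RootFeatureName {
  // Index of the feature the root splits on; None if the tree is a single leaf.
  def rootFeatureIndex(tree: DecisionTreeClassificationModel): Option[Int] =
    tree.rootNode match {
      case n: InternalNode => Some(n.split.featureIndex)
      case _               => None
    }

  // Look the index up in the (assumed) ordered list of assembled column names.
  def nameForIndex(featureCols: Array[String], idx: Int): Option[String] =
    if (idx >= 0 && idx < featureCols.length) Some(featureCols(idx)) else None

  // Alternative: read names from the ML attribute metadata of the features column.
  // Returns None when the column carries no attribute metadata.
  def namesFromMetadata(df: DataFrame, featuresCol: String): Option[Array[String]] =
    AttributeGroup.fromStructField(df.schema(featuresCol)).attributes
      .map(_.zipWithIndex.map { case (attr, i) => attr.name.getOrElse(s"feature_$i") })
}
```

With the `treeModel` from the code above, a call such as `RootFeatureName.rootFeatureIndex(treeModel).flatMap(RootFeatureName.nameForIndex(featureCols, _))` would return the name, where `featureCols` is whatever ordered array of column names was used to build the vector.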
Thanks in advance for your help.