Is there a way to extract the column label of the root node in a Spark 2.1.3 mllib (Scala) decision tree classifier model?

Asked: 2019-01-03 02:30:30

Tags: scala apache-spark apache-spark-mllib decision-tree

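I am currently working with a dataset that requires decision tree classification. Following the mllib example, the DataFrame I build stores its features as a SparseVector. I know that I can extract the root node's impurity and prediction, but I would like to be able to get the column name for that node.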

Suppose my DataFrame looks like this:

+-----|-------|-------|-------+
| id  | col_1 | col_2 | col_3 |
+-----|-------|-------|-------+
| 0   | 1.0   | 0.0   | 2.0   |
| 1   | 2.0   | 1.0   | 0.0   |
| 3   | 2.0   | 2.0   | 1.0   |
+-----|-------|-------|-------+

My final dataset has many more columns than this; this is only to show the kind of data I am working with.

I then use the example code to transform it with a VectorIndexer, which gives me something like this:

+-----|---------------------------------+
| id  | features                        |
+-----|---------------------------------+
| 0   | (3, [0, 2], [1.0, 2.0])         |
| 1   | (3, [0, 1], [2.0, 1.0])         |
| 3   | (3, [0, 1, 2], [2.0, 2.0, 1.0]) |
+-----|---------------------------------+
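
A vector column like the one above is usually built by assembling the named columns, since VectorIndexer itself expects an existing vector column as its input. The following is only a minimal sketch of that step, assuming the toy columns from the table and a local SparkSession; the names `df`, `featureCols`, `assembler` and `assembled` are illustrative and do not appear in the original code. Spark decides on its own whether each row is stored as a sparse or a dense vector.

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

// Assumption: a local session; in spark-shell `spark` already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("assemble-features")
  .getOrCreate()
import spark.implicits._

// Toy data mirroring the table in the question.
val df = Seq(
  (0, 1.0, 0.0, 2.0),
  (1, 2.0, 1.0, 0.0),
  (3, 2.0, 2.0, 1.0)
).toDF("id", "col_1", "col_2", "col_3")

// The order of this array defines the position of each column in the vector:
// col_1 -> index 0, col_2 -> index 1, col_3 -> index 2.
val featureCols = Array("col_1", "col_2", "col_3")

val assembler = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")

val assembled = assembler.transform(df)
assembled.select("id", "features").show(false)
```

Keeping `featureCols` around matters later: the feature numbers printed by the tree are 0-based positions in this vector.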

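Finally, I have the learned classification tree and can get the root node from it. What I want to do is get the column name associated with that root node.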

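import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// `spark` is the SparkSession (predefined in spark-shell).
val data = spark.read.format("libsvm").load("./src/main/resources/sample_libsvm_data.txt")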

val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  .fit(data)

// randomSplit returns an Array, so destructure it with an Array pattern.
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)


val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
println("Pipeline stages: ")
pipeline.getStages.foreach(println)

val model = pipeline.fit(trainingData)

val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)

val root = treeModel.rootNode
println("Root: " + root.toString)

This gives me something like:

Trained classification tree model:
DecisionTreeClassificationModel (uid=dtc_5eadf281a5fc) of depth 4 with 13 nodes
  If (feature 32 in {1.0,2.0,3.0,4.0})
   If (feature 4 in {0.0})
    Predict: 2.0
   Else (feature 4 not in {0.0})
    Predict: 1.0
  Else (feature 32 not in {1.0,2.0,3.0,4.0})
   If (feature 37 in {0.0,2.0})
    If (feature 1 in {1.0})
     Predict: 1.0
    Else (feature 1 not in {1.0})
     If (feature 4 in {0.0,3.0})
      Predict: 0.0
     Else (feature 4 not in {0.0,3.0})
      Predict: 2.0
   Else (feature 37 not in {0.0,2.0})
    If (feature 3 in {0.0})
     Predict: 2.0
    Else (feature 3 not in {0.0})
     Predict: 1.0

Root: InternalNode(prediction = 0.0, impurity = 0.6428571428571429, split = org.apache.spark.ml.tree.CategoricalSplit@a0058edd)

As you can see, the root of the tree is feature 32, but I would like to be able to get its column name.
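
One possible approach, sketched here under assumptions rather than offered as a confirmed answer: the number in `feature 32` is the 0-based position in the feature vector, and it can be read from the root node through `InternalNode.split.featureIndex`. If the vector was built from named scalar columns with a VectorAssembler, that index can be mapped back either through the assembler's input column array or through the ML attribute metadata attached to the vector column. The toy data, the `label` column and all variable names below are hypothetical. With libsvm input, as in the code above, there are no column names to recover, only indices.

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tree.InternalNode
import org.apache.spark.sql.SparkSession

// Assumption: a local session; in spark-shell `spark` already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("root-feature-name")
  .getOrCreate()
import spark.implicits._

// Hypothetical toy data with a numeric label and three named feature columns.
val df = Seq(
  (0.0, 1.0, 0.0, 2.0),
  (1.0, 2.0, 1.0, 0.0),
  (1.0, 2.0, 2.0, 1.0),
  (0.0, 1.0, 1.0, 2.0)
).toDF("label", "col_1", "col_2", "col_3")

val featureCols = Array("col_1", "col_2", "col_3")
val assembled = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")
  .transform(df)

val treeModel = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .fit(assembled)

// 0-based index of the feature tested at the root; None if the tree is a single leaf.
val rootFeatureIndex: Option[Int] = treeModel.rootNode match {
  case internal: InternalNode => Some(internal.split.featureIndex)
  case _                      => None
}

// Option 1: look the index up in the assembler's input columns.
// Only valid when every input column contributes exactly one vector slot.
val nameFromInputCols: Option[String] = rootFeatureIndex.map(i => featureCols(i))

// Option 2: read the attribute names stored in the vector column's metadata.
val attrs = AttributeGroup.fromStructField(assembled.schema("features")).attributes
val nameFromMetadata: Option[String] = for {
  i     <- rootFeatureIndex
  as    <- attrs
  name  <- as(i).name
} yield name

println(s"Root feature index: $rootFeatureIndex")
println(s"Column (input cols): $nameFromInputCols, column (metadata): $nameFromMetadata")
```

The metadata route is more general, because it still lines up when some assembled inputs are themselves vectors, whereas indexing into `featureCols` assumes one slot per column. In a pipeline like the one above, the same lookup would start from `model.stages(2)` and the index would refer to positions in the column passed to `setFeaturesCol`.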

Thanks in advance for your help.

0 answers