如何获取Spark Decision Tree模型的节点信息

时间:2018-04-12 20:22:00

标签: python scala pyspark apache-spark-mllib apache-spark-ml

我想通过Spark MLlib的决策树获得有关生成模型的每个节点的更多详细信息。我可以使用API​​获得的最接近的是print(model.toDebugString()),返回类似这样的内容(取自PySpark doc)

  DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0

如何修改MLlib源代码以获得每个节点的杂质和深度? (如有必要,我如何在PySpark中调用新的Scala函数?)

2 个答案:

答案 0 :(得分:1)

不幸的是,我找不到任何方法在PySpark或Spark(Scala API)中直接访问节点 。但是有一种方法可以从根节点开始并遍历到不同的节点。

(我刚刚在这里提到了杂质,但是对于深度,我可以轻松替换impuritysubtreeDepth。)

假设决策树模型实例为dt

PySpark

root = dt.call("topNode")
root.impurity() # gives the impurity of the root node

现在,如果我们查看适用于root的方法:

dir(root)
[u'apply', u'deepCopy', u'emptyNode', u'equals', 'getClass', u'getNode', u'hashCode', u'id', 'impurity', u'impurity_$eq', u'indexToLevel', u'initializeLogIfNecessary', u'isLeaf', u'isLeaf_$eq', u'isLeftChild', u'isTraceEnabled', u'leftChildIndex', u'leftNode', u'leftNode_$eq', u'log', u'logDebug', u'logError', u'logInfo', u'logName', u'logTrace', u'logWarning', u'maxNodesInLevel', u'notify', u'notifyAll', u'numDescendants', u'org$apache$spark$internal$Logging$$log_', u'org$apache$spark$internal$Logging$$log__$eq', u'parentIndex', u'predict', u'predict_$eq', u'rightChildIndex', u'rightNode', u'rightNode_$eq', u'split', u'split_$eq', u'startIndexInLevel', u'stats', u'stats_$eq', u'subtreeDepth', u'subtreeIterator', u'subtreeToString', u'subtreeToString$default$1', u'toString', u'wait']

我们可以做到:

root.leftNode().get().impurity()

这可以在树中更深入,例如:

root.leftNode().get().rightNode().get().impurity()

自应用leftNode()rightNode()后,我们转到option,应用get或getOrElse is necessary to get to the desired节点类型。

如果你想知道我如何使用这些奇怪的方法,我必须承认,我作弊!!,即我首先研究了Scala API:

火花

以下几行与上述内容完全等效,假设dt相同,则给出相同的结果:

val root = dt.topNode
root.impurity

我们可以做到:

root.leftNode.get.impurity

这可以在树中更深入,例如:

root.leftNode.get.rightNode.get.impurity

答案 1 :(得分:0)

我将通过描述我如何使用PySpark 2.4.3来完成@mostOfMajority的回答。

根节点

给出训练有素的决策树模型,这就是获取其根节点的方法:

def _get_root_node(tree: DecisionTreeClassificationModel):
    return tree._call_java('rootNode')

杂质

我们可以通过从根节点向下走到树上来获得杂质。其pre-order transversal可以这样完成:

def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
    def recur(node):
        if node.numDescendants() == 0:
            return []
        ni = node.impurity()
        return (
            recur(node.leftChild()) + [ni] + recur(node.rightChild())
        )
    return recur(_get_root_node(tree))

示例

In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
  If (feature 0 <= 6.5)
   If (feature 0 <= 3.5)
    Predict: 1.0
   Else (feature 0 > 3.5)
    If (feature 0 <= 5.0)
     Predict: 0.0
    Else (feature 0 > 5.0)
     Predict: 1.0
  Else (feature 0 > 6.5)
   Predict: 0.0


In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]