我想通过Spark MLlib的决策树获得有关生成模型的每个节点的更多详细信息。我可以使用API获得的最接近的是print(model.toDebugString())
,返回类似这样的内容(取自PySpark doc)
DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.0)
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
如何修改MLlib源代码以获得每个节点的杂质和深度? (如有必要,我如何在PySpark中调用新的Scala函数?)
答案 0 :(得分:1)
不幸的是,我找不到任何方法在PySpark或Spark(Scala API)中直接访问节点 。但是有一种方法可以从根节点开始并遍历到不同的节点。
(我刚刚在这里提到了杂质,但是对于深度,我可以轻松替换impurity
与subtreeDepth
。)
假设决策树模型实例为dt
:
root = dt.call("topNode")
root.impurity() # gives the impurity of the root node
现在,如果我们查看适用于root
的方法:
dir(root)
[u'apply', u'deepCopy', u'emptyNode', u'equals', 'getClass', u'getNode', u'hashCode', u'id', 'impurity', u'impurity_$eq', u'indexToLevel', u'initializeLogIfNecessary', u'isLeaf', u'isLeaf_$eq', u'isLeftChild', u'isTraceEnabled', u'leftChildIndex', u'leftNode', u'leftNode_$eq', u'log', u'logDebug', u'logError', u'logInfo', u'logName', u'logTrace', u'logWarning', u'maxNodesInLevel', u'notify', u'notifyAll', u'numDescendants', u'org$apache$spark$internal$Logging$$log_', u'org$apache$spark$internal$Logging$$log__$eq', u'parentIndex', u'predict', u'predict_$eq', u'rightChildIndex', u'rightNode', u'rightNode_$eq', u'split', u'split_$eq', u'startIndexInLevel', u'stats', u'stats_$eq', u'subtreeDepth', u'subtreeIterator', u'subtreeToString', u'subtreeToString$default$1', u'toString', u'wait']
我们可以做到:
root.leftNode().get().impurity()
这可以在树中更深入,例如:
root.leftNode().get().rightNode().get().impurity()
自应用leftNode()
或rightNode()
后,我们转到option
,应用get
或getOrElse is necessary to get to the desired
节点类型。
如果你想知道我如何使用这些奇怪的方法,我必须承认,我作弊!!,即我首先研究了Scala API:
以下几行与上述内容完全等效,假设dt
相同,则给出相同的结果:
val root = dt.topNode
root.impurity
我们可以做到:
root.leftNode.get.impurity
这可以在树中更深入,例如:
root.leftNode.get.rightNode.get.impurity
答案 1 :(得分:0)
我将通过描述我如何使用PySpark 2.4.3来完成@mostOfMajority的回答。
给出训练有素的决策树模型,这就是获取其根节点的方法:
def _get_root_node(tree: DecisionTreeClassificationModel):
return tree._call_java('rootNode')
我们可以通过从根节点向下走到树上来获得杂质。其pre-order transversal可以这样完成:
def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
def recur(node):
if node.numDescendants() == 0:
return []
ni = node.impurity()
return (
recur(node.leftChild()) + [ni] + recur(node.rightChild())
)
return recur(_get_root_node(tree))
In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
If (feature 0 <= 6.5)
If (feature 0 <= 3.5)
Predict: 1.0
Else (feature 0 > 3.5)
If (feature 0 <= 5.0)
Predict: 0.0
Else (feature 0 > 5.0)
Predict: 1.0
Else (feature 0 > 6.5)
Predict: 0.0
In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]