使用scikit-learn时,如何找到树分裂的属性?

时间:2013-11-23 00:29:51

标签: python machine-learning scikit-learn decision-tree

我一直在探索scikit-learn,使用熵和基尼分裂标准制定决策树,并探索差异。

我的问题是,我怎样才能“打开引擎盖”并确切地找出每个级别树木分裂的属性及其相关的信息值,这样我就能看出这两个标准在哪里做出不同的选择? / p>

到目前为止,我已经探讨了文档中概述的9种方法。它们似乎不允许访问此信息。但是这些信息肯定是可以访问的吗?我正在设想一个包含节点和增益条目的列表或字典。

感谢您的帮助和道歉,如果我错过了一些非常明显的事情。

2 个答案:

答案 0 :(得分:30)

直接来自文档(http://scikit-learn.org/0.12/modules/tree.html):

from io import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
  Python3不再支持

StringIO模块,而是导入io模块。

决策树对象中还有tree_属性,允许直接访问整个结构。

你可以简单地阅读它

clf.tree_.children_left #array of left children
clf.tree_.children_right #array of right children
clf.tree_.feature #array of nodes splitting feature
clf.tree_.threshold #array of nodes splitting points
clf.tree_.value #array of nodes values

有关详细信息,请查看source code of export method

通常,您可以使用inspect模块

from inspect import getmembers
print( getmembers( clf.tree_ ) )

获取所有对象的元素

Decision tree visualization from sklearn docs

答案 1 :(得分:2)

Scikit Learn在0.21版(2019年5月)中引入了一种名为export_text的美味新方法,可从树中查看所有规则。 Documentation here

一旦适合模型,您只需要两行代码。首先,导入export_text

from sklearn.tree.export import export_text

第二,创建一个包含规则的对象。要使规则看起来更具可读性,请使用feature_names参数并传递功能名称列表。例如,如果您的模型名为model,并且要素在名为X_train的数据框中命名,则可以创建一个名为tree_rules的对象:

tree_rules = export_text(model, feature_names=list(X_train))

然后仅打印或保存tree_rules。您的输出将如下所示:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1