我可以以文本格式查看或导出决策树的节点吗?

时间:2017-11-21 01:15:14

标签: python scikit-learn

我在scikit-learn中通过RandomForestClassifier使用随机森林,并希望检查生成的决策树的节点上的决策标准。我可以看到一种在拟合后通过DecisionTreeClassifier)访问特定RandomForestClassifier.ensemble_[i]的方法,我可以看到一种导出树以生成graphviz图像的方法(通过sklearn.tree.export_graphviz())。但是我看不到以任何方式描述树的方法比图像更简单 - 特别是我只想要一种人类可读的文本格式。

具体来说:graphviz树图像在每个节点上包含描述该节点的决策标准和结果的文本。我想要的是能够生成这个每节点文本,以及哪些节点是哪些节点的子节点的规范,但就像文本一样 - 没有嵌入到图像或点文件中。从技术上讲,点文件是文本,但它是为渲染图像而设计的,如果您想要的只是了解树,则难以阅读。 scikit-learn中是否有任何导出函数可以产生某种人类可读的DecisionTreeClassifier描述?

我的后备计划是通过修改sklearn.tree.export_graphviz()来源来编写我自己的函数,但我想知道是否已存在某些内容。

1 个答案:

答案 0 :(得分:0)

将其发布为答案,因为我无法发表评论。

link用于scikit学习方法sklearn.tree.export_text,该方法应根据我从页面获取的以下代码块输出文本:

>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree.export import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'])
>>> print(r)
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2
...

我尝试使用它的缺点,但这会发生:

from sklearn.tree.export import export_text
Traceback (most recent call last):
  File "<input>", line 1, in <module>
ImportError: cannot import name 'export_text'

也许您可以获得比我更好的结果,如果可以,请告诉我。我正在使用scikit-learn == 0.20.3