我在scikit-learn中通过RandomForestClassifier
使用随机森林,并希望检查生成的决策树的节点上的决策标准。我可以看到一种在拟合后通过DecisionTreeClassifier
)访问特定RandomForestClassifier.ensemble_[i]
的方法,我可以看到一种导出树以生成graphviz图像的方法(通过sklearn.tree.export_graphviz()
)。但是我看不到以任何方式描述树的方法比图像更简单 - 特别是我只想要一种人类可读的文本格式。
具体来说:graphviz树图像在每个节点上包含描述该节点的决策标准和结果的文本。我想要的是能够生成这个每节点文本,以及哪些节点是哪些节点的子节点的规范,但就像文本一样 - 没有嵌入到图像或点文件中。从技术上讲,点文件是文本,但它是为渲染图像而设计的,如果您想要的只是了解树,则难以阅读。 scikit-learn中是否有任何导出函数可以产生某种人类可读的DecisionTreeClassifier
描述?
我的后备计划是通过修改sklearn.tree.export_graphviz()
来源来编写我自己的函数,但我想知道是否已存在某些内容。
答案 0 :(得分:0)
将其发布为答案,因为我无法发表评论。
此link用于scikit学习方法sklearn.tree.export_text
,该方法应根据我从页面获取的以下代码块输出文本:
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree.export import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'])
>>> print(r)
|--- petal width (cm) <= 0.80
| |--- class: 0
|--- petal width (cm) > 0.80
| |--- petal width (cm) <= 1.75
| | |--- class: 1
| |--- petal width (cm) > 1.75
| | |--- class: 2
...
我尝试使用它的缺点,但这会发生:
from sklearn.tree.export import export_text
Traceback (most recent call last):
File "<input>", line 1, in <module>
ImportError: cannot import name 'export_text'
也许您可以获得比我更好的结果,如果可以,请告诉我。我正在使用scikit-learn == 0.20.3