为决策树添加正确的标签

时间:2019-05-01 07:50:32

标签: scikit-learn decision-tree

我正在机器学习项目中使用随机森林回归。为了更好地理解预测的逻辑,我想可视化一些决策树,并检查何时使用哪些功能。

为此,我编写了以下代码:

from sklearn.tree import export_graphviz
from subprocess import call
from IPython.display import Image

# Select one estimator from the Random Forests
estimator = best_estimators_regr['RandomForestRegressor'][0].estimators_[0]

export_graphviz(estimator, out_file=path+'tree.dot', 
           rounded=True, proportion=False, 
           precision=2, filled=True)
call(['dot', '-Tpng', path+'tree.dot', '-o', path+'tree.png', '-Gdpi=600'])
Image(filename=path+'tree.png')

问题是训练模型时我使用了max_features参数,所以我不知道每棵树中使用了哪些功能。因此,在绘制树时,我只得到X[some_number]。这个数字对应于原始数据集中的列吗?如果没有,我如何告诉它使用列名而不是数字?

1 个答案:

答案 0 :(得分:1)

'max_features'中的RandomForestClassifier参数用于一次获取要素数量以找到最佳分割。该参数将传递给所有单独的估算器(DecisionTreeClassifier)。基本的DecisionTreeClassifier对象都接受整个数据(其中从训练数据中采样样本,但是所有列特征都传递给每棵树)。功能排序确定为单个DecisionTreeClassifier对象。因此,不必为此担心。

您可以只使用feature_names中的export_graphviz参数来传递所有功能的每个功能的名称。