I have trained a DecisionTreeClassifier to predict a binary target, and then visualized the tree as follows:
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
dt = DecisionTreeClassifier(min_samples_split=20,max_depth=5,criterion="entropy",random_state=99)
dt.fit(df[important_features], target)
from sklearn.tree import export_graphviz
from sklearn.tree import _tree
import graphviz
export_graphviz(dt, out_file="tree.dot",
                feature_names=important_features)
with open("tree.dot") as f:
    dot_graph = f.read()
graphviz.Source(dot_graph)
In the tree, the nodes at the last level show information like this, for example:
entropy = 0.9908
samples = 67465
value = [37538, 29927]
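For reference, the numbers in such a node can be checked by hand. The sketch below assumes that value holds the per-class sample counts in the order of dt.classes_ (so index 0 is class 0 and index 1 is class 1 for a 0/1 target); the variable names are only illustrative.

```python
import math

# Per-class sample counts copied from the node above; assumed to be ordered
# like dt.classes_, i.e. [count of class 0, count of class 1].
value = [37538, 29927]

samples = sum(value)                      # total samples reaching the node
probs = [v / samples for v in value]      # class proportions in the node

# Shannon entropy in bits, the impurity used with criterion="entropy".
entropy = -sum(p * math.log2(p) for p in probs if p > 0)

# Index of the majority class, which is what the node would predict.
majority_index = max(range(len(value)), key=value.__getitem__)

print(samples, round(entropy, 4), majority_index)
```

Under that assumption, the counts add up to the reported samples, and the node would predict the class with the larger count.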
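However, I cannot find target, which should be a binary variable. I expected to see it at the last level of the tree so that I could analyze the rules/paths leading to 0 and 1. What am I doing wrong?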
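One possible explanation, offered as a sketch rather than a definitive answer: export_graphviz only prints a class line in each node when class_names is passed. The example below uses synthetic data from make_classification as a stand-in for df[important_features] and target, and the feature names f0..f3 are hypothetical. It also shows how the majority class of each leaf can be read from dt.tree_ without plotting.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Synthetic stand-in for df[important_features] and target (assumption).
X, y = make_classification(n_samples=500, n_features=4, random_state=99)
feature_names = ["f0", "f1", "f2", "f3"]  # hypothetical feature names

dt = DecisionTreeClassifier(min_samples_split=20, max_depth=5,
                            criterion="entropy", random_state=99)
dt.fit(X, y)

# class_names must follow the order of dt.classes_ (ascending labels).
# With out_file=None the DOT source is returned as a string.
dot_graph = export_graphviz(dt, out_file=None,
                            feature_names=feature_names,
                            class_names=[str(c) for c in dt.classes_],
                            filled=True)

# Same information without the plot: majority class of every node.
# tree_.value has shape (n_nodes, n_outputs, n_classes); depending on the
# scikit-learn version it holds counts or proportions, argmax is the same.
node_class = dt.classes_[np.argmax(dt.tree_.value[:, 0, :], axis=1)]
is_leaf = dt.tree_.children_left == -1   # leaves have no left child
leaf_classes = node_class[is_leaf]
```

With class_names set, each node label should include a line such as "class = 0" or "class = 1", and dot_graph can be passed to graphviz.Source as before.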