我从scikit-learn了解决策树的结果有两个问题。例如,这是我的决策树之一:
我的问题是我如何使用树?
第一个问题是:如果样本满足条件,则进入 LEFT 分支(如果存在),否则进入 RIGHT 。在我的情况下,如果样本有X [7]> 63521.3984。然后样品将进入绿色框。正确的吗?
第二个问题是:当样本到达叶节点时,我怎么知道它属于哪个类别?在这个例子中,我有三个类别要分类。在红色框中,分别满足91,212和113个样本的条件。但是我该如何确定类别呢? 我知道有一个函数 clf.predict(样本)来告诉该类别。我可以从图表中做到吗??? 非常感谢。
答案 0 :(得分:26)
每个框中的value
行告诉您该节点上有多少样本按顺序分配到每个类别中。这就是为什么在每个框中,value
中的数字加起来为sample
中显示的数字。例如,在您的红色框中,91 + 212 + 113 = 416。所以这意味着如果你到达这个节点,那么类别1中有91个数据点,类别2中有212个数据点,类别3中有113个数据点。
如果您要预测在决策树中到达该叶子的新数据点的结果,您可以预测类别2,因为这是该节点上样本的最常见类别。
答案 1 :(得分:5)
第一个问题: 是的,你的逻辑是正确的。左侧节点为True,右侧节点为False。这可能是违反直觉的; true可以等同于较小的样本。
第二个问题: 通过使用pydotplus将树可视化为图形,可以最好地解决此问题。 tree.export_graphviz()的'class_names'属性将向每个节点的多数类添加一个类声明。代码在iPython笔记本中执行。
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)
with open("iris.dot", 'w') as f:
f = tree.export_graphviz(clf, out_file=f)
import os
os.unlink('iris.dot')
import pydotplus
dot_data = tree.export_graphviz(clf2, out_file=None)
graph2 = pydotplus.graph_from_dot_data(dot_data)
graph2.write_pdf("iris.pdf")
from IPython.display import Image
dot_data = tree.export_graphviz(clf2, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True, # leaves_parallel=True,
special_characters=True)
graph2 = pydotplus.graph_from_dot_data(dot_data)
## Color of nodes
nodes = graph2.get_node_list()
for node in nodes:
if node.get_label():
values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')];
color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],}
values = color[values.index(max(values))]; # print(values)
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color)
node.set_fillcolor(color )
#
Image(graph2.create_png() )
至于确定叶子上的类,你的例子没有一个类的叶子,就像虹膜数据集一样。这很常见,可能需要过度拟合模型才能获得这样的结果。对于许多交叉验证的模型,离散的类分布是最佳结果。
享受代码!
答案 2 :(得分:1)
根据“学习scikit-learn:Python中的机器学习”一书,决策树代表了一系列基于训练数据的决策。
!(http://i.imgur.com/vM9fJLy.png)
要对实例进行分类,我们应该回答每个节点的问题。例如,性别<= 0.5? (我们在谈论一个女人吗?) 如果答案是肯定的,则转到树中的左子节点;否则你会转到正确的子节点。你一直在回答问题(她是在第三堂课吗?,她是在头等舱吗?,她是13岁以下吗?),直到你到达一片叶子。 当你在那里时,预测对应于具有大多数实例的目标类。
答案 3 :(得分:0)
将feature_names = X.columns添加到tree.export_graphviz,其中X是训练数据。
我的代码如下
with open("lectureGini.txt", "w") as f:
f = tree.export_graphviz(lectureGini, out_file=f,feature_names=X.columns)
# copy contents of file LectureGini.txt into WebGraphviz - http://webgraphviz.com/
lectureGini是我的DecisionTreeClassifier的输出
这是我发现的一种简单方法,可以将其添加到我研究过的所有Gini索引的网络示例中。所有的网络示例都很好地解释了该方法,但是没有一个示例显示如何找到类别。 我尚未安装Graphviz,所以我要从jupyter导出文本文件并将其复制到Webgraphwiz