如何从scikit-learn解释决策树

时间:2014-05-09 05:58:02

标签: python numpy scipy scikit-learn decision-tree

我从scikit-learn了解决策树的结果有两个问题。例如,这是我的决策树之一:

enter image description here 我的问题是我如何使用树?

第一个问题是:如果样本满足条件,则进入 LEFT 分支(如果存在),否则进入 RIGHT 。在我的情况下,如果样本有X [7]> 63521.3984。然后样品将进入绿色框。正确的吗?

第二个问题是:当样本到达叶节点时,我怎么知道它属于哪个类别?在这个例子中,我有三个类别要分类。在红色框中,分别满足91,212和113个样本的条件。但是我该如何确定类别呢? 我知道有一个函数 clf.predict(样本)来告诉该类别。我可以从图表中做到吗??? 非常感谢。

4 个答案:

答案 0 :(得分:26)

每个框中的value行告诉您该节点上有多少样本按顺序分配到每个类别中。这就是为什么在每个框中,value中的数字加起来为sample中显示的数字。例如,在您的红色框中,91 + 212 + 113 = 416。所以这意味着如果你到达这个节点,那么类别1中有91个数据点,类别2中有212个数据点,类别3中有113个数据点。

如果您要预测在决策树中到达该叶子的新数据点的结果,您可以预测类别2,因为这是该节点上样本的最常见类别。

答案 1 :(得分:5)

第一个问题: 是的,你的逻辑是正确的。左侧节点为True,右侧节点为False。这可能是违反直觉的; true可以等同于较小的样本。

第二个问题: 通过使用pydotplus将树可视化为图形,可以最好地解决此问题。 tree.export_graphviz()的'class_names'属性将向每个节点的多数类添加一个类声明。代码在iPython笔记本中执行。

from sklearn.datasets import load_iris  
from sklearn import tree  
iris = load_iris()  
clf2 = tree.DecisionTreeClassifier()  
clf2 = clf2.fit(iris.data, iris.target)  

with open("iris.dot", 'w') as f:  
    f = tree.export_graphviz(clf, out_file=f)  

import os  
os.unlink('iris.dot')  

import pydotplus  
dot_data = tree.export_graphviz(clf2, out_file=None)  
graph2 = pydotplus.graph_from_dot_data(dot_data)  
graph2.write_pdf("iris.pdf")  

from IPython.display import Image  
dot_data = tree.export_graphviz(clf2, out_file=None,  
                     feature_names=iris.feature_names,  
                     class_names=iris.target_names,  
                     filled=True, rounded=True,  # leaves_parallel=True, 
                     special_characters=True)  
graph2 = pydotplus.graph_from_dot_data(dot_data)

## Color of nodes
nodes = graph2.get_node_list()

for node in nodes:
    if node.get_label():
        values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')];
        color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],}
        values = color[values.index(max(values))]; # print(values)
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color)
        node.set_fillcolor(color )
#

Image(graph2.create_png() ) 

enter image description here

至于确定叶子上的类,你的例子没有一个类的叶子,就像虹膜数据集一样。这很常见,可能需要过度拟合模型才能获得这样的结果。对于许多交叉验证的模型,离散的类分布是最佳结果。

享受代码!

答案 2 :(得分:1)

根据“学习scikit-learn:Python中的机器学习”一书,决策树代表了一系列基于训练数据的决策。

!(http://i.imgur.com/vM9fJLy.png

要对实例进行分类,我们应该回答每个节点的问题。例如,性别<= 0.5? (我们在谈论一个女人吗?) 如果答案是肯定的,则转到树中的左子节点;否则你会转到正确的子节点。你一直在回答问题(她是在第三堂课吗?,她是在头等舱吗?,她是13岁以下吗?),直到你到达一片叶子。 当你在那里时,预测对应于具有大多数实例的目标类

答案 3 :(得分:0)

将feature_names = X.columns添加到tree.export_graphviz,其中X是训练数据。

我的代码如下

with open("lectureGini.txt", "w") as f:
    f = tree.export_graphviz(lectureGini, out_file=f,feature_names=X.columns)
# copy contents of file LectureGini.txt into WebGraphviz - http://webgraphviz.com/

lectureGini是我的DecisionTreeClassifier的输出

这是我发现的一种简单方法,可以将其添加到我研究过的所有Gini索引的网络示例中。所有的网络示例都很好地解释了该方法,但是没有一个示例显示如何找到类别。 我尚未安装Graphviz,所以我要从jupyter导出文本文件并将其复制到Webgraphwiz

相关问题