scikit学习决策树导出graphviz - 决策树中的错误类名

时间:2016-12-18 11:33:17

标签: scikit-learn graphviz decision-tree

我在决策树中得到了错误的类名,来自" scikit learn / decision tree / export graphviz"。该计划如下所示:

import matplotlib.pyplot as plt
import matplotlib.image as img
import pydot
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']

digital_tree = tree.DecisionTreeClassifier()
digital_tree.fit(digital_table, digital_label)

with open("digital.dot", 'w') as f:
    f = tree.export_graphviz(digital_tree, 
                            feature_names=digital_name,
                            class_names=digital_label,
                            filled=True, rounded=True,
                            out_file=f)
(graph,) = pydot.graph_from_dot_file("digital.dot")
graph.write_png("digital.png")

plt.imshow(img.imread('digital.png'))
plt.show()

输出如下:

the decision tree

问题在于叶子中显示的类名。例如,绿色框应标记为'三个'如果idx-1同时为1而idx-2为1.但是,图像将标签显示为' one'。有人可以发表你的意见吗?

2 个答案:

答案 0 :(得分:2)

当您使用DecisionTreeClassifier时,您应该将类​​标签更改为数字,如0,1,2

然后使用:

classe_names = decision_tree_classifier.classes_

它将按升序为您提供班级的标签。然后以相同的顺序指定class_label。它可以是字符串。

答案 1 :(得分:0)

在将类标签传递给export_graphviz

之前,请尝试按字母顺序对其进行排序