如何理解scikit-learn中DecisionTreeClassifier的输出?

时间:2013-09-12 01:00:11

标签: scikit-learn

我正在学习ML并使用scikit-learn进行基本决策树分类。

功能的值是分类的,所以我使用DictVectorizer来转换原始功能值。这是我的代码:

training_set # list of dict representing the traing set
labels # corresponding labels of the training set
vec = DictVectorizer()
vectorized = vec.fit_transform(training_set)
clf = tree.DecisionTreeClassifier()
clf.fit(vectorized.toarray(), labels)

with open("output.dot", "w") as output_file:
    tree.export_graphviz(clf, out_file=output_file)

但我不理解输出图。它包含一个树,每个节点标记为X[1] <= 0.5000或类似的东西。我所期望的是标有FEATURE_1 == VALUE_1的节点,un-vectorized信息显示在树上。

有可能吗?

更新:

例如,FEATURE_1有三个可能的值ABC,后者又被矢量化为0,00,11,0分别。我想要的图表是FEATURE_1 == A而不是X[1] <= 0.5

enter image description here

4 个答案:

答案 0 :(得分:9)

您可以将要素名称传递给树导出方法:

with open("output.dot", "w") as output_file:
    tree.export_graphviz(clf, feature_names=vec.get_feature_names(),
                         out_file=output_file)

分类器本身并不知道数据的“含义”,它只处理连续的数值,因此需要使用矢量化器对分类变量进行热编码,将二进制变量安全地视为连续变量。 [0, 1]范围内的变量,所有实际值都是0或1,两者之间没有任何内容。

要了解DictVectorizer如何进行热门编码,请查看文档中的example snippet

答案 1 :(得分:1)

如果您有二进制变量,

X[1] <= 0.5000表示X[1] = 0。如果等式成立,则选择左分支。否则,右分支。你当然可以解析点文件并覆盖它(它只是一个文本文件,并且很容易用正则表达式),但它最初构造的方式是这样修复的,因为默认情况下树的节点是不等式。

答案 2 :(得分:0)

当值在连续间隔中时,机器学习者将对值进行排序并查找所有中间值以查找具有最高基尼指数的值。

这是合理的,因为在连续域中,找到具有精确值的测试实例的可能性,比方说,3.1415为零。在这种情况下,分类器不应该知道该怎么做。

我不知道scikit-learn,但在WEKA中,可以指定值是继续还是离散。

答案 3 :(得分:0)

当你执行export_graphviz时,指定feature_names,在这种情况下为自变量DataFrame指定列名。

这会产生输出文件中的列名,如下所示。

model = clf.fit(X, y)

dot_data = tree.export_graphviz(model, out_file=None, feature_names=X.columns.values.tolist(), class_names = None, filled=True, rounded=True, special_characters=True)

with open("output.dot", "w") as output_file:
    output_file.write(dot_data)