Tree classifier to graphviz ERROR

Time: 2017-08-13 13:36:14

Tags: python-3.x scikit-learn

I created a tree classifier named model and tried to use the export_graphviz function as follows:

export_graphviz(decision_tree=model,
                    out_file='NT_model.dot',
                    feature_names=X_train.columns,
                    class_names=model.classes_,
                    leaves_parallel=True,
                    filled=True,
                    rotate=False,
                    rounded=True)

For some reason, running it raises this exception:

  
TypeError                                 Traceback (most recent call last)
<ipython-input-298-40fe56bb0c85> in <module>()
      6                     filled=True,
      7                     rotate=False,
----> 8                     rounded=True)

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-packages\sklearn\tree\export.py in export_graphviz(decision_tree, out_file, max_depth, feature_names, class_names, label, filled, leaves_parallel, impurity, node_ids, proportion, rotate, rounded, special_characters)
    431             recurse(decision_tree, 0, criterion="impurity")
    432         else:
--> 433             recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
    434
    435         # If required, draw leaf nodes at same depth as each other

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-packages\sklearn\tree\export.py in recurse(tree, node_id, criterion, parent, depth)
    319             out_file.write('%d [label=%s'
    320                            % (node_id,
--> 321                               node_to_str(tree, node_id, criterion)))
    322
    323             if filled:

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-packages\sklearn\tree\export.py in node_to_str(tree, node_id, criterion)
    289                                           np.argmax(value),
    290                                           characters[2])
--> 291             node_string += class_name
    292
    293         # Clean up any trailing newlines

TypeError: ufunc 'add' did not contain a loop with signature matching types dtype('<U90') dtype('<U90') dtype('<U90')

The hyperparameters of the model being visualized are:

print(model)
DecisionTreeClassifier(class_weight={1.0: 10, 0.0: 1}, criterion='gini',
        max_depth=7, max_features=None, max_leaf_nodes=None,
        min_impurity_split=1e-07, min_samples_leaf=50,
        min_samples_split=2, min_weight_fraction_leaf=0.0,
        presort=False, random_state=0, splitter='best')

print(model.classes_)
[ 0. , 1. ]

Thanks a lot for any help!

1 answer:

Answer 0 (score: 1)

As pointed out in the documentation of export_graphviz, the class_names param expects strings, not floats or ints.
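To see why a float class name breaks the export, here is a deliberately simplified stand-in for the step the traceback fails on (`node_string += class_name`). `build_class_label` is a hypothetical helper, not a scikit-learn function; it only assumes that the node label is assembled by plain string concatenation.

```python
def build_class_label(class_names, class_index):
    """Hypothetical, simplified version of the step that appends the class
    name to a node label (cf. ``node_string += class_name`` in the traceback)."""
    node_string = "class = "
    # Appending only works if the stored class name is already a string.
    node_string += class_names[class_index]
    return node_string


# String names concatenate cleanly.
print(build_class_label(["0.0", "1.0"], 1))  # class = 1.0

# Float names (like the entries of model.classes_ in the question) cannot be
# appended to a str, so a TypeError is raised.
try:
    build_class_label([0.0, 1.0], 1)
except TypeError as exc:
    print("TypeError:", exc)
```

In the question the class names are NumPy `float64` values rather than Python floats, which is likely why the failure surfaces as NumPy's `ufunc 'add'` message instead of Python's own concatenation error.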

  

class_names : list of strings, bool or None, optional (default=None)
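The `bool` option in that signature is a fallback when readable names are not needed: with `class_names=True`, the documentation says a symbolic representation of the class name is shown (labels like `y[0]`). A sketch on synthetic data with float labels; the dataset is an assumption standing in for the asker's data:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Synthetic binary data with float labels 0.0 / 1.0 (assumed stand-in data).
X, y = make_classification(n_samples=200, n_features=4, random_state=0)
y = y.astype(float)

model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# class_names=True: no names are supplied, so nodes get symbolic labels.
# out_file=None makes export_graphviz return the DOT source as a string.
dot_source = export_graphviz(model, out_file=None, class_names=True)
print("class = y[" in dot_source)  # True
```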

Try converting model.classes_ to a list of strings before passing it to export_graphviz.

In the call to export_graphviz(), try class_names=['0', '1'] or class_names=['0.0', '1.0'].

For a more general solution, use:

class_names=[str(x) for x in model.classes_]
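An end-to-end sketch of that fix, reusing the export options from the question. The data and the feature names `f0` to `f3` are hypothetical stand-ins for `X_train`, and `out_file=None` is used so the DOT text is returned instead of being written to `NT_model.dot`:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Synthetic stand-in for X_train / y_train; labels are floats as in the question.
X_train, y_train = make_classification(n_samples=200, n_features=4, random_state=0)
y_train = y_train.astype(float)
feature_names = ["f0", "f1", "f2", "f3"]  # hypothetical column names

model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

# Convert every class label to str before handing it to export_graphviz.
class_names = [str(c) for c in model.classes_]
print(class_names)  # ['0.0', '1.0']

dot_source = export_graphviz(
    decision_tree=model,
    out_file=None,  # return the DOT text instead of writing a file
    feature_names=feature_names,
    class_names=class_names,
    leaves_parallel=True,
    filled=True,
    rotate=False,
    rounded=True,
)
```

Passing `out_file='NT_model.dot'` instead writes the same DOT source to disk, as in the original call. Only the string type matters here, so the hard-coded `['0', '1']` works too if shorter labels are preferred.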

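But is there a specific reason you are passing float values as y to model.fit()? In classification tasks that is mostly unnecessary. Are you using only the actual y labels, or did you convert string labels to numbers before fitting the model?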
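If the original labels are strings, one option, sketched here with synthetic data and made-up labels `"no"`/`"yes"`, is to fit on them directly: `classes_` then already holds strings and can go into `class_names` without conversion.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_graphviz

X, y01 = make_classification(n_samples=200, n_features=4, random_state=0)

# Hypothetical string labels in place of numeric 0/1 codes.
y_str = np.where(y01 == 1, "yes", "no")

model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y_str)
# classes_ is sorted, so "no" comes before "yes".
print(model.classes_)

# The labels are already strings, so they can go straight into class_names.
dot_source = export_graphviz(model, out_file=None, class_names=list(model.classes_))
```

For 0.0/1.0 labels, casting with `y.astype(int)` before fitting gives integer classes instead, but per the documentation quoted above those still need `str()` before being used as `class_names`.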