I created a tree classifier named model and tried to use the export_graphviz function as follows:
export_graphviz(decision_tree=model,
out_file='NT_model.dot',
feature_names=X_train.columns,
class_names=model.classes_,
leaves_parallel=True,
filled=True,
rotate=False,
rounded=True)
For some reason, running it raises this exception:
TypeError                                 Traceback (most recent call last)
<ipython-input-298-40fe56bb0c85> in <module>()
      6                 filled=True,
      7                 rotate=False,
----> 8                 rounded=True)

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-packages\sklearn\tree\export.py in export_graphviz(decision_tree, out_file, max_depth, feature_names, class_names, label, filled, leaves_parallel, impurity, node_ids, proportion, rotate, rounded, special_characters)
    431             recurse(decision_tree, 0, criterion="impurity")
    432         else:
--> 433             recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
    434
    435         # If required, draw leaf nodes at same depth as each other

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-packages\sklearn\tree\export.py in recurse(tree, node_id, criterion, parent, depth)
    319             out_file.write('%d [label=%s'
    320                            % (node_id,
--> 321                               node_to_str(tree, node_id, criterion)))
    322
    323             if filled:

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site-packages\sklearn\tree\export.py in node_to_str(tree, node_id, criterion)
    289                                           np.argmax(value),
    290                                           characters[2])
--> 291             node_string += class_name
    292
    293         # Clean up any trailing newlines

TypeError: ufunc 'add' did not contain a loop with signature matching types dtype('<U90') dtype('<U90') dtype('<U90')
The hyperparameters of the model being visualized are:
print(model)
DecisionTreeClassifier(class_weight={1.0: 10, 0.0: 1}, criterion='gini',
max_depth=7, max_features=None, max_leaf_nodes=None,
min_impurity_split=1e-07, min_samples_leaf=50,
min_samples_split=2, min_weight_fraction_leaf=0.0,
presort=False, random_state=0, splitter='best')
print(model.classes_)
[ 0. , 1. ]
Any help is much appreciated!
Answer 0 (score: 1)
As stated in the documentation of export_graphviz, the class_names parameter expects strings, not floats or ints.
class_names : list of strings, bool or None, optional (default=None)
Try converting model.classes_ to a list of strings before passing it to export_graphviz.
Try class_names=['0', '1'] or class_names=['0.0', '1.0'] in the call to export_graphviz().
For a more general solution, use:
class_names=[str(x) for x in model.classes_]
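Put together, the conversion and the corrected call might look like the sketch below. The helper names class_names_as_str and export_tree are hypothetical, introduced here only for illustration; they are not part of scikit-learn. The sketch assumes model is a fitted DecisionTreeClassifier and that feature_names is an iterable of column names such as X_train.columns.

```python
def class_names_as_str(classes):
    """Return the class labels as a plain list of Python strings."""
    return [str(label) for label in classes]


def export_tree(model, feature_names, out_file="NT_model.dot"):
    """Export a fitted tree to DOT format with string class names.

    Hypothetical wrapper; requires scikit-learn to be installed.
    """
    # Imported lazily so the helper above can be used without scikit-learn.
    from sklearn.tree import export_graphviz

    return export_graphviz(
        decision_tree=model,
        out_file=out_file,
        feature_names=list(feature_names),
        class_names=class_names_as_str(model.classes_),  # strings, not floats
        leaves_parallel=True,
        filled=True,
        rotate=False,
        rounded=True,
    )


# Labels shaped like model.classes_ in the question
names = class_names_as_str([0.0, 1.0])
print(names)  # ['0.0', '1.0']
```

With the fitted model from the question, the call would be export_tree(model, X_train.columns).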
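But is there a specific reason you are passing float values as y to model.fit()? That is mostly unnecessary in classification tasks. Are these your actual y labels, or are you converting string labels to numbers before fitting the model?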