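I'm trying to draw a tree diagram from CHAID output. There are plenty of examples, but they all seem to apply only to binary splits. The tree I'm generating is a CHAID tree, whose nodes can split more than two ways.
I've tried a few solutions, but they always show the output as text instead of drawing the tree: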
import codecs
import random
from io import StringIO
import pandas as pd
from CHAID import Tree
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from IPython.display import SVG
from graphviz import Source
from IPython.display import display
def to_graphviz(self, filename=None, shape='circle', graph='digraph'):
    """Exports the tree in the dot format of the graphviz software"""
    nodes, connections = [], []
    if self.nodes:
        for n in self.expand_tree(mode=self.WIDTH):
            nid = self[n].identifier
            state = '"{0}" [label="{1}", shape={2}]'.format(
                nid, self[n].tag, shape)
            nodes.append(state)
            for c in self.children(nid):
                cid = c.identifier
                connections.append('"{0}" -> "{1}"'.format(nid, cid))
    # write nodes and connections to dot format
    is_plain_file = filename is not None
    if is_plain_file:
        f = codecs.open(filename, 'w', 'utf-8')
    else:
        f = StringIO()
    f.write(graph + ' tree {\n')
    for n in nodes:
        f.write('\t' + n + '\n')
    if len(connections) > 0:
        f.write('\n')
    for c in connections:
        f.write('\t' + c + '\n')
    f.write('}')
    if not is_plain_file:
        print(f.getvalue())
    f.close()
data = ["a", "b", "c", "d"]
df = pd.DataFrame([random.sample(data, 4) for _ in range(500)])
df.columns = ['a', 'b', 'c', 'd']
## set the CHAID input parameters
independent_variable_columns = ['a', 'b', 'c']
dep_variable = 'd'
tree = Tree.from_pandas_df(df, dict(zip(independent_variable_columns, ['nominal'] * 3)), dep_variable,min_child_node_size=0)
tree.print_tree()
This produces:
([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))
|-- (['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))
| |-- (['b'], {'a': 0, 'b': 0, 'c': 16.0, 'd': 20.0}, <Invalid Chaid Split> - the max depth has been reached)
| |-- (['c'], {'a': 0, 'b': 13.0, 'c': 0, 'd': 19.0}, <Invalid Chaid Split> - the max depth has been reached)
| +-- (['d'], {'a': 0, 'b': 18.0, 'c': 18.0, 'd': 0}, <Invalid Chaid Split> - the max depth has been reached)
|-- (['b'], {'a': 41.0, 'b': 0, 'c': 51.0, 'd': 41.0}, (('a',), p=2.2336212592454823e-14, score=70.03309116568674, groups=[['a'], ['c'], ['d']]), dof=4)) etc
That is plain text, not a diagram with boxes, labels, and so on.
I also tried:
treeG=tree.to_tree()
display(SVG(to_graphviz(treeG).pipe(format='svg')))
But this produces much the same result:
digraph tree {
"0" [label="([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))", shape=circle]
"1" [label="(['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))", shape=circle] etc..
If anyone knows how to turn this into a proper diagram, I would be very grateful.
Thanks + BR
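For reference, here is a rough sketch of the direction I suspect this needs to go, which I have not managed to get working with the real tree. As far as I can tell, to_graphviz above prints the DOT text and returns nothing, so my guess is that the text has to be returned as a string and then handed to graphviz's Source before anything can be drawn. The (node_id, label, parent_id) tuples and the names tree_to_dot and render_svg are placeholders I made up for illustration, not part of CHAID or treelib, and render_svg assumes both the graphviz Python package and the Graphviz dot binary are installed.

```python
def tree_to_dot(nodes, shape="box"):
    """Return DOT source for a tree given as (node_id, label, parent_id) tuples.

    parent_id is None for the root. This tuple layout is a hypothetical
    stand-in for whatever the real CHAID tree exposes.
    """
    lines = ["digraph tree {"]
    for node_id, label, _ in nodes:
        # Escape backslashes and double quotes so the label stays one DOT string.
        safe = str(label).replace("\\", "\\\\").replace('"', '\\"')
        lines.append('\t"{0}" [label="{1}", shape={2}]'.format(node_id, safe, shape))
    for node_id, _, parent_id in nodes:
        if parent_id is not None:
            # One edge per child, so a node may have any number of branches.
            lines.append('\t"{0}" -> "{1}"'.format(parent_id, node_id))
    lines.append("}")
    return "\n".join(lines)


def render_svg(dot_source):
    """Render DOT text to SVG bytes (assumes graphviz and the dot binary exist)."""
    from graphviz import Source  # imported lazily so tree_to_dot works without it
    return Source(dot_source).pipe(format="svg")


# Hypothetical three-way split, mirroring the shape of the CHAID output above.
example = [
    (0, "root: split on b", None),
    (1, "b = a", 0),
    (2, "b = b", 0),
    (3, "b = c", 0),
]
dot = tree_to_dot(example)
print(dot)
# In a notebook, something like: display(SVG(render_svg(dot)))
```

The point of returning the string instead of printing it is that the caller then has something to pass to Source; whether the labels from the real tree need further trimming to be readable is a separate question.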