How to plot a tree diagram from CHAID tree output

Time: 2019-08-29 13:30:40

Tags: python pandas tree graphviz

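I am trying to plot a tree diagram from CHAID output. There are plenty of examples, but they only seem to work for binary splits. The tree I am generating is a CHAID tree, whose nodes can split into more than two branches.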

I have tried a few solutions, but they always show the output as text instead of producing a tree diagram:

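import codecs  # used by to_graphviz when writing to a file
import random
from io import StringIO

import pandas as pd
from CHAID import Tree
from graphviz import Source
from IPython.display import SVG, display
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Helper adapted from treelib: note that it prints the DOT text and returns None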
def to_graphviz(self, filename=None, shape='circle', graph='digraph'):
        """Exports the tree in the dot format of the graphviz software"""
        nodes, connections = [], []
        if self.nodes:

            for n in self.expand_tree(mode=self.WIDTH):
                nid = self[n].identifier
                state = '"{0}" [label="{1}", shape={2}]'.format(
                    nid, self[n].tag, shape)
                nodes.append(state)

                for c in self.children(nid):
                    cid = c.identifier
                    connections.append('"{0}" -> "{1}"'.format(nid, cid))

        # write nodes and connections to dot format
        is_plain_file = filename is not None
        if is_plain_file:
            f = codecs.open(filename, 'w', 'utf-8')
        else:
            f = StringIO()

        f.write(graph + ' tree {\n')
        for n in nodes:
            f.write('\t' + n + '\n')

        if len(connections) > 0:
            f.write('\n')

        for c in connections:
            f.write('\t' + c + '\n')

        f.write('}')

        if not is_plain_file:
            print(f.getvalue())

        f.close()
# Build 500 random rows, each a permutation of the four category labels
data = ["a", "b", "c", "d"]
df = pd.DataFrame([random.sample(data, 4) for _ in range(500)])
df.columns = ['a', 'b', 'c', 'd']  # a flat list; a nested list would create a MultiIndex

# Set the CHAID input parameters
independent_variable_columns = ['a', 'b', 'c']
dep_variable = 'd'
tree = Tree.from_pandas_df(
    df,
    dict(zip(independent_variable_columns, ['nominal'] * 3)),
    dep_variable,
    min_child_node_size=0,
)
tree.print_tree()

This produces:

([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))
|-- (['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))
|   |-- (['b'], {'a': 0, 'b': 0, 'c': 16.0, 'd': 20.0}, <Invalid Chaid Split> - the max depth has been reached)
|   |-- (['c'], {'a': 0, 'b': 13.0, 'c': 0, 'd': 19.0}, <Invalid Chaid Split> - the max depth has been reached)
|   +-- (['d'], {'a': 0, 'b': 18.0, 'c': 18.0, 'd': 0}, <Invalid Chaid Split> - the max depth has been reached)
|-- (['b'], {'a': 41.0, 'b': 0, 'c': 51.0, 'd': 41.0}, (('a',), p=2.2336212592454823e-14, score=70.03309116568674, groups=[['a'], ['c'], ['d']]), dof=4)) etc

This is plain text, rather than a diagram with boxes, labels, and so on.
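For reference, Graphviz's DOT format is not restricted to binary trees: a node may have any number of outgoing edges. The following is a minimal sketch, assuming a hypothetical nested-dict tree (the `label` and `children` keys and the `tree_to_dot` name are illustrative and are not part of CHAID), that produces DOT source with box-shaped nodes for a three-way split.

```python
from itertools import count


def tree_to_dot(tree, shape="box"):
    """Return DOT source for a nested {"label": str, "children": [...]} tree."""
    ids = count()  # hands out 0, 1, 2, ... as node identifiers
    lines = ["digraph tree {"]

    def visit(node):
        node_id = next(ids)
        # Escape double quotes so the label stays a valid DOT string
        label = node["label"].replace('"', r'\"')
        lines.append(f'\t"{node_id}" [label="{label}", shape={shape}]')
        for child in node.get("children", []):
            child_id = visit(child)
            lines.append(f'\t"{node_id}" -> "{child_id}"')
        return node_id

    visit(tree)
    lines.append("}")
    return "\n".join(lines)


# Hypothetical three-way split, loosely modelled on the printed output above
example = {
    "label": "root: split on b",
    "children": [
        {"label": "b = a"},
        {"label": "b = b"},
        {"label": "b = c"},
    ],
}
dot_source = tree_to_dot(example)
```

The resulting string is only DOT text; it still has to be passed to Graphviz to become an image.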

I have also tried:

treeG=tree.to_tree()
display(SVG(to_graphviz(treeG).pipe(format='svg')))

But this produces much the same kind of result, text instead of a picture:

digraph tree {
    "0" [label="([], {'a': 130.0, 'b': 113.0, 'c': 134.0, 'd': 123.0}, (('b',), p=1.888741762227604e-33, score=177.262344476939, groups=[['a'], ['b'], ['c'], ['d']]), dof=9))", shape=circle]
    "1" [label="(['a'], {'a': 0, 'b': 31.0, 'c': 34.0, 'd': 39.0}, (('c',), p=8.919582934418456e-11, score=52.9052990371776, groups=[['b'], ['c'], ['d']]), dof=4))", shape=circle] etc..

If anyone knows how to convert this into a proper diagram, I would be very grateful.
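One observation that may help: the listing above is DOT source, not a rendered image. The `to_graphviz` helper in the question writes the DOT text to a `StringIO`, prints it, and returns nothing, so there is no object to call `.pipe` on. Below is a rough sketch of one possible workaround, assuming the helper is left unchanged: capture what it prints and hand the text to `graphviz.Source`. The names `print_dot`, `capture_dot` and `dot_to_svg` are hypothetical, and `print_dot` is only a stand-in for the real helper.

```python
import io
from contextlib import redirect_stdout


def print_dot():
    """Stand-in for a helper that prints DOT source and returns None."""
    print('digraph tree {\n'
          '\t"0" [label="root", shape=box]\n'
          '\t"1" [label="leaf", shape=box]\n'
          '\n'
          '\t"0" -> "1"\n'
          '}')


def capture_dot(func, *args, **kwargs):
    """Run func and return whatever it printed to stdout as a string."""
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        func(*args, **kwargs)
    return buffer.getvalue()


def dot_to_svg(dot_source):
    """Render DOT source to SVG bytes (needs the graphviz package and binaries)."""
    from graphviz import Source  # imported lazily so the rest runs without it
    return Source(dot_source).pipe(format="svg")


dot_source = capture_dot(print_dot)
```

In a notebook, something like `display(SVG(dot_to_svg(capture_dot(to_graphviz, treeG))))` might then show the diagram, provided the Graphviz binaries are installed. Changing the helper to return `f.getvalue()` instead of printing it would be a simpler alternative if editing it is acceptable.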

Thanks + BR

0 answers:

No answers