我怎样才能在Python中绘制CART树,就像在R中一样?

时间:2014-04-02 22:38:39

标签: python r graph tree scikit-learn

在R中,我可以直接使用API​​绘制与CART模型对应的决策树的图形表示。例如,prp将产生类似

的内容

但我找不到任何类似的API用于Python中的等效功能。例如,我尽可能地告诉sklearn' s RandomForestClassifierDecisionTreeClassifier都没有方法或绘图树。

如何在Python中获得CART或随机林树的图形表示?

3 个答案:

答案 0 :(得分:5)

使用export_graphviz功能。

from sklearn.tree import DecisionTreeClassifier, export_graphviz
np.random.seed(0)
X = np.random.randn(10, 4)
y = array(["foo", "bar", "baz"])[np.random.randint(0, 3, 10)]
clf = DecisionTreeClassifier(random_state=42).fit(X, y)
export_graphviz(clf)

现在dotty tree.dot应该显示类似

的内容

tree visualization

这是notebook

答案 1 :(得分:1)

除了此处列出的其他方法之外,从scikit-learn 21.0版开始(大约在2019年5月),现在可以使用scikit-learn的tree.plot_tree和matplotlib来绘制决策树,而无需依赖graphviz。

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

X, y = load_iris(return_X_y=True)

# Make an instance of the Model
clf = DecisionTreeClassifier()

# Train the model on the data
clf.fit(X, y)

fn=['sepal length (cm)','sepal width (cm)','petal length (cm)','petal width (cm)']
cn=['setosa', 'versicolor', 'virginica']

# Setting dpi = 300 to make image clearer than default
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=300)

tree.plot_tree(clf,
           feature_names = fn, 
           class_names=cn,
           filled = True);

fig.savefig('imagename.png')

下面的图像是保存的图像。 enter image description here

此代码改编自post

答案 2 :(得分:0)

此功能将使图形显示在Jupyter笔记本中:

# Imports
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image, display
import pydotplus

def jupyter_graphviz(m, **kwargs):
    dot_data = StringIO()
    export_graphviz(m, dot_data, **kwargs)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
    display(Image(graph.create_png()))

例如:

import sklearn.datasets as datasets
import pandas as pd

iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
dtree = DecisionTreeClassifier(random_state=42)
dtree.fit(df, y)

jupyter_graphviz(dtree, filled=True, rounded=True, special_characters=True)

Tree visualization

这是notebook的改版中的this post