这是我第一次尝试在graphviz上绘制树,下面是我的代码:
# import Python package to generate tree
from sklearn import tree
from sklearn.tree import export_graphviz
from sfi import Data
from sfi import Macro
from IPython.display import Image
from PIL import Image
import io
import matplotlib.pyplot as plt
import pydotplus
import pandas as pd
import os
# change working directory
os.chdir("/Users/hyunjindominiquecho/Desktop")
# Use the sfi Data class to pull data from Stata variables
# Note here, that I am using the Python-Stata interface to
# import the Stata dataset onto Python via "Data.get"
X = pd.DataFrame(Data.get("speaker1 speaker2 sex1 sex2 sp"),
columns = ['speaker1', 'speaker2', 'sex1', 'sex2', 'sp'])
# drop 'sp' from the original dataframe X for data analysis
X_no_sp = X.drop(columns=['sp'], axis=1)
y = pd.DataFrame(Data.get("type sp"), columns = ['type','sp'])
# get number of observations in pharmacy_small dataset
nobs = X.shape[0]
# seperate the transfered variables into the appropriate train and test set
# based on the value of 'sp' (to be consistent with the split that was done on the Stata).
# generating a python copy of the split sample is necessary,
# since we are going to perform the naive bayes on python, which will require pandas dataframe as an input.
X_train = X.loc[X['sp'] == 1]
X_test = X.loc[X['sp'] == 2]
y_train = y.loc[y['sp'] == 1]
y_test = y.loc[y['sp'] == 2]
# get rid of 'sp' from each pandas dataframe, since 'sp' shouldn't be included in our analysis
del X_train['sp']
del X_test['sp']
del y_train['sp']
del y_test['sp']
# create tree objects
model_gini_class = tree.DecisionTreeClassifier(criterion='gini')
# train the model using the training sets and check score
model_gini_class.fit(X_train, y_train)
model_gini_class.score(X_train, y_train)
# predict output
# note here that we are making predictions for all of the observations in X
# (not just X_test)
# this is so that I can more easily import the prediction as Stata variable,
# as Stata variables have to have their length equal to the total num of observations.
predicted_gini_class = model_gini_class.predict(X_no_sp)
# plot the two trees (Entropy and Gini)
dot_data = io.StringIO()
tree.export_graphviz(model_gini_class,
out_file=dot_data,
rounded = True,
filled = True,
feature_names=X_no_sp.columns)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png("treeGini.png")
但是我对此图有问题,因为从上面的树图中,如果您从顶部看第二行,您会看到在左侧的绿色矩形中指定了sex2
为sex2 <= 1.5
。由于sex2
是一个指标变量,对于女性来说等于1,对于男性来说等于2,所以我希望树显示类似sex2 <= male
而不是sex2 <= 1.5
的东西。
我该如何实现?谢谢