我正在执行下面的代码并在其后出现错误 -
from IPython.display import Image
from sklearn.tree import export_graphviz
from six import StringIO
import pydotplus
features = list(df.columns[:-1])
features
dot_data = StringIO()
export_graphviz(dtree, out_file=dot_data,feature_names=features,filled=True,rounded=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
错误-
InvocationException Traceback (most recent call last)
<ipython-input-98-1978b4285d97> in <module>
3
4 graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
----> 5 Image(graph.create_png())
~\anaconda3\lib\site-packages\pydotplus\graphviz.py in <lambda>(f, prog)
1789 self.__setattr__(
1790 'create_' + frmt,
-> 1791 lambda f=frmt, prog=self.prog: self.create(format=f, prog=prog)
1792 )
1793 f = self.__dict__['create_' + frmt]
~\anaconda3\lib\site-packages\pydotplus\graphviz.py in create(self, prog, format)
2022
2023 if status != 0:
-> 2024 raise InvocationException(
2025 'Program terminated with status: %d. stderr follows: %s' % (
2026 status, stderr_output))
InvocationException: Program terminated with status: 1. stderr follows: 'C:\Users\Ankit' is not recognized as an internal or external command,
operable program or batch file.
我假设这是因为我的用户名 - Ankit Chawrai。 虽然请告诉我可能的解决方案是什么。
答案 0 :(得分:0)
我正在尝试一个非常相似的示例,它基于机器学习操作手册,该手册使用台湾信用卡数据集预测违约风险。我的设置如下:
from six import StringIO
from sklearn.tree import export_graphviz
from IPython.display import Image
import pydotplus
然后创建决策树图是这样完成的:
dot_data = StringIO()
export_graphviz(decision_tree=class_tree,
out_file=dot_data,
filled=True,
rounded=True,
feature_names = X_train.columns,
class_names = ['pay','default'],
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
我认为这一切都来自 out_file=dot_data
参数,但无法确定文件路径的创建和存储位置,因为 print(dot_data.getvalue())
没有显示任何路径名。
在我的研究中,我遇到了 sklearn.plot_tree(),它似乎可以完成 graphviz 所做的一切。所以我采用了上面的 exporet_graphviz 参数,并且匹配的参数在我添加的 .plot_tree 方法中。
我最终得到了与文本中相同的图像:
from sklearn import tree
plt.figure(figsize=(20, 10))
tree.plot_tree(class_tree,
filled=True, rounded=True,
feature_names = X_train.columns,
class_names = ['pay','default'],
fontsize=12)
plt.show()