我有这段代码可以将决策树从scikit_learn
转换为JSON。
def treeToJson(decision_tree, feature_names=None):
from warnings import warn
js = ""
def node_to_str(tree, node_id, criterion):
if not isinstance(criterion, sklearn.tree.tree.six.string_types):
criterion = "impurity"
value = tree.value[node_id]
if tree.n_outputs == 1:
value = value[0, :]
jsonValue = ', '.join([str(x) for x in value])
if tree.children_left[node_id] == sklearn.tree._tree.TREE_LEAF:
return '"id": "%s", "criterion": "%s", "impurity": "%s", "samples": "%s", "value": [%s]' \
% (node_id,
criterion,
tree.impurity[node_id],
tree.n_node_samples[node_id],
jsonValue)
else:
if feature_names is not None:
feature = feature_names[tree.feature[node_id]]
else:
feature = tree.feature[node_id]
if "=" in feature:
ruleType = "="
ruleValue = "false"
else:
ruleType = "<="
ruleValue = "%.4f" % tree.threshold[node_id]
return '"id": "%s", "rule": "%s %s %s", "%s": "%s", "samples": "%s"' \
% (node_id,
feature,
ruleType,
ruleValue,
criterion,
tree.impurity[node_id],
tree.n_node_samples[node_id])
def recurse(tree, node_id, criterion, parent=None, depth=0):
tabs = " " * depth
js = ""
left_child = tree.children_left[node_id]
right_child = tree.children_right[node_id]
js = js + "\n" + \
tabs + "{\n" + \
tabs + " " + node_to_str(tree, node_id, criterion)
if left_child != sklearn.tree._tree.TREE_LEAF:
js = js + ",\n" + \
tabs + ' "left": ' + \
recurse(tree, \
left_child, \
criterion=criterion, \
parent=node_id, \
depth=depth + 1) + ",\n" + \
tabs + ' "right": ' + \
recurse(tree, \
right_child, \
criterion=criterion, \
parent=node_id,
depth=depth + 1)
js = js + tabs + "\n" + \
tabs + "}"
return js
if isinstance(decision_tree, sklearn.tree.tree.Tree):
js = js + recurse(decision_tree, 0, criterion="impurity")
else:
js = js + recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
return js
当我叫它时:
treetojson = treeToJson(clf, feature_cols)
我收到此错误:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-56-d9b330d3ed27> in <module>
----> 1 treetojson = treeToJson(clf, feature_cols)
<ipython-input-55-7e42fe63109a> in treeToJson(decision_tree, feature_names)
74 return js
75
---> 76 if isinstance(decision_tree, sklearn.tree.tree.Tree):
77 js = js + recurse(decision_tree, 0, criterion="impurity")
78 else:
AttributeError: module 'sklearn.tree' has no attribute 'tree'
答案 0 :(得分:0)
我正在检查分类决策树,它是BaseDecisionTree的子类。因此您可以更改
if isinstance(decision_tree, sklearn.tree.tree.Tree)
到
if isinstance(decision_tree, sklearn.tree.BaseDecisionTree)