将scikit学习决策树解析为JSON

时间:2020-05-17 10:37:31

标签: python json scikit-learn

我有这段代码可以将决策树从scikit_learn转换为JSON。

    def treeToJson(decision_tree, feature_names=None):
      from warnings import warn

      js = ""

      def node_to_str(tree, node_id, criterion):
        if not isinstance(criterion, sklearn.tree.tree.six.string_types):
          criterion = "impurity"

        value = tree.value[node_id]
        if tree.n_outputs == 1:
          value = value[0, :]

        jsonValue = ', '.join([str(x) for x in value])

        if tree.children_left[node_id] == sklearn.tree._tree.TREE_LEAF:
          return '"id": "%s", "criterion": "%s", "impurity": "%s", "samples": "%s", "value": [%s]' \
                 % (node_id, 
                    criterion,
                    tree.impurity[node_id],
                    tree.n_node_samples[node_id],
                    jsonValue)
        else:
          if feature_names is not None:
            feature = feature_names[tree.feature[node_id]]
          else:
            feature = tree.feature[node_id]

          if "=" in feature:
            ruleType = "="
            ruleValue = "false"
          else:
            ruleType = "<="
            ruleValue = "%.4f" % tree.threshold[node_id]

          return '"id": "%s", "rule": "%s %s %s", "%s": "%s", "samples": "%s"' \
                 % (node_id, 
                    feature,
                    ruleType,
                    ruleValue,
                    criterion,
                    tree.impurity[node_id],
                    tree.n_node_samples[node_id])

      def recurse(tree, node_id, criterion, parent=None, depth=0):
        tabs = "  " * depth
        js = ""

        left_child = tree.children_left[node_id]
        right_child = tree.children_right[node_id]

        js = js + "\n" + \
             tabs + "{\n" + \
             tabs + "  " + node_to_str(tree, node_id, criterion)

        if left_child != sklearn.tree._tree.TREE_LEAF:
          js = js + ",\n" + \
               tabs + '  "left": ' + \
               recurse(tree, \
                       left_child, \
                       criterion=criterion, \
                       parent=node_id, \
                       depth=depth + 1) + ",\n" + \
               tabs + '  "right": ' + \
               recurse(tree, \
                       right_child, \
                       criterion=criterion, \
                       parent=node_id,
                       depth=depth + 1)

        js = js + tabs + "\n" + \
             tabs + "}"

        return js

      if isinstance(decision_tree, sklearn.tree.tree.Tree):
        js = js + recurse(decision_tree, 0, criterion="impurity")
      else:
        js = js + recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)

      return js

当我叫它时:

treetojson = treeToJson(clf, feature_cols)

我收到此错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-56-d9b330d3ed27> in <module>
----> 1 treetojson = treeToJson(clf, feature_cols)

<ipython-input-55-7e42fe63109a> in treeToJson(decision_tree, feature_names)
     74     return js
     75 
---> 76   if isinstance(decision_tree, sklearn.tree.tree.Tree):
     77     js = js + recurse(decision_tree, 0, criterion="impurity")
     78   else:

AttributeError: module 'sklearn.tree' has no attribute 'tree'

1 个答案:

答案 0 :(得分:0)

我正在检查分类决策树,它是BaseDecisionTree的子类。因此您可以更改

if isinstance(decision_tree, sklearn.tree.tree.Tree)

if isinstance(decision_tree, sklearn.tree.BaseDecisionTree)