如何在python中提取随机森林的决策规则?

时间:2019-05-14 08:49:57

标签: python-3.x

我正在从随机森林中提取决策规则,并且已阅读参考链接:

How to extract the decision rules of a random forest in Python

此代码输出为:

<input type="text" id="textbox" hidden />

if ($(this).parent().text().trim() === "Other")
   {
       if ($('#textbox').is(':hidden')) {
           $('#textbox').removeAttr("hidden");
        }
        else {
             $('#textbox').attr("hidden", "hidden");
             }
    }

但这不是我想要的输出:它只是把整棵树打印了出来,并没有提取出决策规则。

理想的输出是:

TREE: 0
0 NODE: if feature[33] < 2.5 then next=1 else next=4
1 NODE: if feature[38] < 0.5 then next=2 else next=3
2 LEAF: return class=2
3 LEAF: return class=9
4 NODE: if feature[50] < 8.5 then next=5 else next=6
5 LEAF: return class=4
6 LEAF: return class=0
...

我不知道如何生成理想的输出。期待您的帮助!

2 个答案:

答案 0 :(得分:1)

这是根据您的要求的解决方案。它将为您提供随机森林中每个基学习器(即 sklearn 的 RandomForestClassifier 中 n_estimators 棵决策树)所使用的决策规则。

Serial_Raw  Total_Remaining
-----------------------------
1           10
2           0
3           20
4           100

我从这里获得了决策规则代码 How to extract the decision rules from scikit-learn decision-tree?

from sklearn import metrics, datasets, ensemble
from sklearn.tree import _tree

#Decision Rules to code utility
def dtree_to_code(tree, feature_names, tree_idx):
    """Print the rules of one fitted decision tree as Python-like pseudo-code.

    tree: a fitted sklearn decision-tree estimator (exposes ``.tree_``).
    feature_names: display names for the features, indexed by feature id.
    tree_idx: position of this tree in the forest; used in the emitted
        function name ``tree_<idx>``.
    """
    inner = tree.tree_
    # Pre-map every node to the display name of the feature it splits on;
    # leaf nodes carry _tree.TREE_UNDEFINED as their feature id.
    node_names = [
        "undefined!" if fid == _tree.TREE_UNDEFINED else feature_names[fid]
        for fid in inner.feature
    ]
    print('def tree_{1}({0}):'.format(", ".join(feature_names), tree_idx))

    def walk(idx, depth):
        pad = "  " * depth
        if inner.feature[idx] == _tree.TREE_UNDEFINED:
            # Leaf node: emit the stored value vector.
            print('{0}return {1}'.format(pad, inner.value[idx]))
            return
        split_name = node_names[idx]
        thr = inner.threshold[idx]
        print('{0}if {1} <= {2}:'.format(pad, split_name, thr))
        walk(inner.children_left[idx], depth + 1)
        print('{0}else:  # if {1} > {2}'.format(pad, split_name, thr))
        walk(inner.children_right[idx], depth + 1)

    walk(0, 1)
def rf_to_code(rf, feature_names):
    """Emit the decision rules of every base learner in a fitted forest."""
    # One generated pseudo-code function per estimator in the ensemble.
    idx = 0
    for learner in rf.estimators_:
        dtree_to_code(tree=learner, feature_names=feature_names, tree_idx=idx)
        idx += 1

如果一切顺利,则输出:

#clf : RandomForestClassifier(n_estimator=100)
#df :  Iris Dataframe

rf_to_code(rf=clf,feature_names=df.columns)

由于 n_estimators = 100 ,您将总共获得100个此类函数。

答案 1 :(得分:0)

基于另一个答案...交叉兼容并且只使用一个变量 X。

import numpy as np
from sklearn import metrics, datasets, ensemble
from sklearn.tree import _tree

#Decision Rules to code utility
def dtree_to_code(fout, tree, variables, feature_names, tree_idx):
    """Append one tree of a random forest to a generated ``predict(X)``.

    fout: writable text file the generated code is appended to.
    tree: fitted sklearn decision-tree estimator (exposes ``.tree_``).
    variables: code expressions for each feature, e.g. ``'X[3]'``,
        indexed by feature id.
    feature_names: human-readable feature names for the comments,
        indexed by feature id.
    tree_idx: position of this tree; the ``def predict(X)`` header is
        written only for the first tree (tree_idx <= 0).
    """
    f = fout
    tree_ = tree.tree_
    # Map each node to the code expression of the feature it splits on;
    # leaves carry _tree.TREE_UNDEFINED as their feature id.
    feature_name = [
        variables[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    if tree_idx <= 0:
        f.write('def predict(X):\n\tret = 0\n')

    def recurse(node, depth):
        indent = "\t" * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # BUGFIX: the original indexed `variables`/`feature_names` by the
            # *node* id; both must be indexed by the node's *feature* id.
            feat = tree_.feature[node]
            variable = variables[feat]
            name = feature_names[feat]
            threshold = tree_.threshold[node]
            f.write('%sif %s <= %s: # if %s <= %s\n' % (indent, variable, threshold, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            f.write('%selse:  # if %s > %s\n' % (indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: majority class at this leaf; class 0 contributes nothing.
            yhat = np.argmax(tree_.value[node][0])
            if yhat != 0:
                f.write("%sret += %s\n" % (indent, yhat))
            else:
                f.write("%spass\n" % (indent))
    recurse(0, 1)
def rf_to_code(f, rf, variables, feature_names):
    """Translate a fitted random forest into one generated ``predict(X)``.

    Writes each base learner's rules via ``dtree_to_code`` and finishes
    the generated function by averaging the accumulated votes.
    """
    n_trees = len(rf.estimators_)
    for tree_idx, learner in enumerate(rf.estimators_):
        dtree_to_code(f, tree=learner, variables=variables,
                      feature_names=feature_names, tree_idx=tree_idx)
    # BUGFIX: the original referenced the loop variable after the loop,
    # raising NameError when `rf.estimators_` was empty.
    if n_trees:
        f.write('\treturn ret/%s\n' % n_trees)

# NOTE(review): `d_q2i` (presumably a word -> feature-index mapping) and
# `estimator` (presumably a fitted RandomForestClassifier) are assumed to
# exist in the surrounding scope -- confirm before running.
with open('_model.py', 'w') as f:
    # Emit a numba-jitted module header, then the generated predict(X).
    f.write('''
from numba import jit,njit
@njit\n''')
    # One comment label per word, e.g. 'w_pizza'.
    labels = ['w_%s'%word for word in d_q2i.keys()]
    # Matching code expressions indexing the feature vector, e.g. 'X[0]'.
    variables = ['X[%s]'%i for i,word in enumerate(d_q2i.keys())]
    rf_to_code(f,estimator,variables,labels)

输出看起来像这样。 X 是表示单个实例特征的一维向量。

from numba import jit,njit
@njit
def predict(X):
    ret = 0
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                ret += 1
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                ret += 1
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    return ret/10