我正在从随机森林中提取决策规则,并且已阅读参考链接:
how extraction decision rules of random forest in python
此代码输出为:
<!-- Hidden text input; the script that follows toggles its "hidden" attribute. -->
<input type="text" id="textbox" hidden />
// NOTE(review): this jQuery fragment toggles a text box and does not look like
// the tree-printing output the surrounding text refers to; it appears misplaced.
// When the parent element's trimmed text is exactly "Other", flip the
// `hidden` attribute of #textbox: hide it if visible, reveal it if hidden.
if ($(this).parent().text().trim() === "Other") {
    if (!$('#textbox').is(':hidden')) {
        $('#textbox').attr("hidden", "hidden");
    } else {
        $('#textbox').removeAttr("hidden");
    }
}
但这并不是理想的输出：它输出的不是规则，只是把树打印了出来。
理想的输出是:
TREE: 0
0 NODE: if feature[33] < 2.5 then next=1 else next=4
1 NODE: if feature[38] < 0.5 then next=2 else next=3
2 LEAF: return class=2
3 LEAF: return class=9
4 NODE: if feature[50] < 8.5 then next=5 else next=6
5 LEAF: return class=4
6 LEAF: return class=0
...
我不知道如何生成理想的输出。期待您的帮助!
答案 0（得分：1）
以下是满足您需求的解决方案。它会给出每个基学习器（base learner）所使用的决策规则（即 sklearn 的 RandomForestClassifier 中 n_estimators 的取值，就是所使用的决策树（DecisionTree）的数量）。
Serial_Raw Total_Remaining
-----------------------------
1 10
2 0
3 20
4 100
决策规则部分的代码取自：How to extract the decision rules from scikit-learn decision-tree?
from sklearn import metrics, datasets, ensemble
from sklearn.tree import _tree
#Decision Rules to code utility
def dtree_to_code(tree, feature_names, tree_idx):
    """
    Print the rules of one fitted decision tree as a Python function.

    The printed source starts with ``def tree_<tree_idx>(<feature_names>):``
    and contains nested ``if <feature> <= <threshold>:`` / ``else:`` branches.
    Each leaf prints ``return <value>``, where ``<value>`` is that node's
    ``tree_.value`` entry converted to a nested Python list.

    Parameters
    ----------
    tree : fitted estimator exposing a ``tree_`` attribute (for example one
        entry of ``RandomForestClassifier.estimators_``).
    feature_names : sequence of str indexed by feature number. The names are
        emitted verbatim as parameter and variable names.
        NOTE(review): columns containing spaces or punctuation will produce
        code that does not compile.
    tree_idx : suffix of the printed function name.
    """
    tree_ = tree.tree_
    print('def tree_{1}({0}):'.format(", ".join(feature_names), tree_idx))

    def recurse(node, depth):
        indent = " " * depth
        # Per scikit-learn's tree-structure docs, a node is a split node
        # exactly when its left and right children differ (leaves have both
        # set to -1). This avoids the private sklearn.tree._tree module.
        if tree_.children_left[node] != tree_.children_right[node]:
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            print('{0}if {1} <= {2}:'.format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print('{0}else: # if {1} > {2}'.format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # tolist() yields a comma-separated literal; str() of a NumPy
            # array has no commas and would not be valid Python source.
            print('{0}return {1}'.format(indent, tree_.value[node].tolist()))

    recurse(0, 1)
def rf_to_code(rf, feature_names):
    """
    Print the rules of every tree in a fitted random forest.

    Trees are taken from ``rf.estimators_`` in order. Each one is handed to
    ``dtree_to_code`` together with its position, which becomes the suffix
    of the printed ``tree_<position>`` function.
    """
    for position, learner in enumerate(rf.estimators_):
        dtree_to_code(learner, feature_names, position)
用法如下（若一切正常，将打印出每棵树对应的函数代码）：
# clf : RandomForestClassifier(n_estimators=100); it must already be fitted,
#       because rf_to_code iterates clf.estimators_.
# df  : Iris DataFrame whose columns are used as the feature names.
# NOTE(review): the column names are emitted verbatim as Python parameter
# names, so names with spaces or parentheses yield code that won't compile.
rf_to_code(rf=clf,feature_names=df.columns)
由于 n_estimators = 100，您总共会得到 100 个这样的函数。
答案 1（得分：0）
基于另一个答案改写：生成的代码具有交叉兼容性（例如可以直接用 numba 的 @njit 编译），并且只使用一个输入变量 X。
from sklearn import metrics, datasets, ensemble
from sklearn.tree import _tree
#Decision Rules to code utility
def dtree_to_code(fout, tree, variables, feature_names, tree_idx):
    """
    Write one fitted decision tree as additive Python code to ``fout``.

    When ``tree_idx <= 0`` the header ``def predict(X):`` / ``ret = 0`` is
    written first. The first tree therefore opens the function, and later
    trees append their branches to the same body.

    Each split node becomes ``if <variable> <= <threshold>:`` / ``else:``,
    with the readable feature name in a trailing comment. Each leaf emits
    ``ret += <k>``, where ``k = argmax(tree_.value[node][0])``, or ``pass``
    when ``k`` is 0.

    Parameters
    ----------
    fout : writable text stream that receives the generated source.
    tree : fitted estimator exposing ``tree_``.
    variables : sequence of str indexed by feature number; the expression
        emitted for each feature (e.g. ``'X[3]'``).
    feature_names : sequence of str indexed by feature number; used only in
        the emitted comments.
    tree_idx : position of this tree; ``<= 0`` triggers the function header.

    NOTE(review): leaves contribute the argmax *index* of ``tree_.value``,
    not a class label. Summing these indices is only a vote count when
    there are two classes.
    """
    # This snippet's own imports (sklearn only) do not bring numpy in.
    import numpy as np

    f = fout
    tree_ = tree.tree_
    if tree_idx <= 0:
        f.write('def predict(X):\n\tret = 0\n')

    def recurse(node, depth):
        indent = "\t" * depth
        # Split nodes have distinct children; leaves have both set to -1.
        if tree_.children_left[node] != tree_.children_right[node]:
            # Look names up by the node's *feature* index. Indexing by the
            # node id attached the wrong feature to most splits.
            feature = tree_.feature[node]
            variable = variables[feature]
            name = feature_names[feature]
            threshold = tree_.threshold[node]
            f.write('%sif %s <= %s: # if %s <= %s\n' % (indent, variable, threshold, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            f.write('%selse: # if %s > %s\n' % (indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            yhat = np.argmax(tree_.value[node][0])
            if yhat != 0:
                f.write("%sret += %s\n" % (indent, yhat))
            else:
                f.write("%spass\n" % (indent))

    recurse(0, 1)
def rf_to_code(f, rf, variables, feature_names):
    """
    Write a whole random forest as a single ``predict(X)`` function to ``f``.

    Every tree in ``rf.estimators_`` is emitted by ``dtree_to_code`` into the
    same function body; the first tree (index 0) writes the header. The
    function then returns ``ret`` divided by the number of trees.

    Raises
    ------
    ValueError
        If ``rf.estimators_`` is empty. Without this check no header would
        be written and the final division would have no tree count.
    """
    estimators = list(rf.estimators_)
    if not estimators:
        raise ValueError("rf.estimators_ is empty; nothing to export")
    for base_learner_id, base_learner in enumerate(estimators):
        dtree_to_code(f, tree=base_learner, variables=variables, feature_names=feature_names, tree_idx=base_learner_id)
    f.write('\treturn ret/%s\n' % (len(estimators)))
# Generate `_model.py`: a numba-decorated predict(X) containing every tree.
# An explicit encoding is used because vocabulary words end up in the emitted
# comments and may be non-ASCII; the platform default encoding could fail.
with open('_model.py', 'w', encoding='utf-8') as f:
    # Module header; @njit decorates the predict() that rf_to_code writes next.
    f.write('''
from numba import jit,njit
@njit\n''')
    # Readable names appear only in comments; `variables` are the expressions
    # actually evaluated, one X[i] per key of d_q2i in iteration order.
    labels = ['w_%s' % word for word in d_q2i.keys()]
    variables = ['X[%s]' % i for i in range(len(d_q2i))]
    rf_to_code(f, estimator, variables, labels)
输出如下所示，其中 X 是表示单个样本特征的一维向量。
# Example of the generated `_model.py`: one if/else cascade per tree, all
# appended to a single predict(X) body. There are 10 cascades here, hence
# the division by 10.
# NOTE(review): every tree tests X[0], X[1], X[2] in that order down its left
# spine. That matches node ids 0, 1, 2 rather than feature ids, consistent
# with the generator looking up `variables[node]` instead of the node's
# feature index. Treat the feature assignments below as suspect.
from numba import jit,njit
@njit
def predict(X):
    ret = 0  # accumulates one leaf contribution per tree
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            ret += 1
    else: # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            pass
    else: # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                ret += 1
        else: # if w_mexico > 0.5
            ret += 1
    else: # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                ret += 1
        else: # if w_mexico > 0.5
            pass
    else: # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            pass
    else: # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            ret += 1
    else: # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            pass
    else: # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            pass
    else: # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            pass
    else: # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else: # if w_reusable > 0.5
                pass
        else: # if w_mexico > 0.5
            pass
    else: # if w_pizza > 0.5
        pass
    return ret/10  # mean over the 10 tree cascades above