I'm trying to generate all the paths from a decision tree in scikit-learn.
The `estimator` here comes from a random forest, so it is a scikit-learn decision tree.
But I'm confused by the data structure of scikit-learn decision trees. It seems that `left` and `right` here hold all of the left and right nodes.
When I try to print the paths, it works fine:
    def get_code(tree, feature_names, tabdepth=0):
        left = tree.tree_.children_left
        right = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value
        def recurse(left, right, threshold, features, node, tabdepth=0):
            if threshold[node] != -2:
                print('\t' * tabdepth + "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                    recurse(left, right, threshold, features, left[node], tabdepth + 1)
                print('\t' * tabdepth + "} else {")
                if right[node] != -1:
                    recurse(left, right, threshold, features, right[node], tabdepth + 1)
                print('\t' * tabdepth + "}")
            else:
                print('\t' * tabdepth + "return " + str(value[node]))
        recurse(left, right, threshold, features, 0)
But I need to collect all the paths into a list, and not record a path when its leaf node is "normal", so I tried the following code:
    def extract_attack_rules(estimator, feature_names):
        left = estimator.tree_.children_left
        right = estimator.tree_.children_right
        threshold = estimator.tree_.threshold
        features = [feature_names[i] for i in estimator.tree_.feature]
        value = estimator.tree_.value
        def recurse(left, right, threshold, features, node):
            path_lst = []
            if threshold[node] != -2:  # not a leaf node
                left_cond = features[node] + "<=" + str(threshold[node])
                right_cond = features[node] + ">" + str(threshold[node])
                if left[node] != -1:  # not a leaf node
                    left_path_lst = recurse(left, right, threshold, features, left[node])
                if right[node] != -1:  # not a leaf node
                    right_path_lst = recurse(left, right, threshold, features, right[node])
                if left_path_lst is not None:
                    path_lst.extend([left_path.append(left_cond) for left_path in left_path_lst])
                if pre_right_path is not None:
                    path_lst.extend([right_path.append(right_cond) for right_path in right_path_lst])
                return path_lst
            else:  # leaf node: the attack type
                if value[node][0][0] > 0:  # leaf is normal: don't collect this path
                    return None
                else:  # attack
                    for i in range(len(value[node][0])):
                        if value[node][0][i] > 0:
                            return [[value[node][0][i]]]
        all_path = recurse(left, right, threshold, features, 0)
        return all_path
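One concrete problem in the code above: `list.append` modifies a list in place and returns `None`, so `[left_path.append(left_cond) for left_path in left_path_lst]` builds a list of `None` values rather than extended paths; `pre_right_path` is also never defined. Below is a minimal sketch of the same recursion that builds new lists instead. It takes the `tree_` arrays directly (with a fitted tree you would pass `estimator.tree_.children_left`, `.children_right`, `.feature`, `.threshold` and `.value`), assumes class index 0 is "normal" (the hypothetical `normal_class` parameter), and labels each leaf with its majority class, which is a different rule from the original "any normal samples" test. The function name and the small arrays are made up for illustration.

```python
# Made-up arrays in the same layout as sklearn's estimator.tree_.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 2.0, -2.0, -2.0]
# value[node][0] holds per-class amounts, shape (n_nodes, n_outputs, n_classes)
value = [[[11, 5, 7]], [[10, 0, 0]], [[1, 5, 7]], [[0, 5, 0]], [[1, 0, 7]]]
feature_names = ["f0", "f1"]


def collect_attack_paths(children_left, children_right, feature, threshold,
                         value, feature_names, normal_class=0):
    """Return root-to-leaf paths whose leaf's majority class is not normal.

    Each path is a list of condition strings (root first) followed by the
    predicted class index. Hypothetical helper, not a scikit-learn API.
    """
    def recurse(node):
        if children_left[node] == -1:  # leaf
            counts = value[node][0]
            predicted = max(range(len(counts)), key=lambda i: counts[i])
            if predicted == normal_class:
                return []  # "normal" leaf: contributes no paths
            return [[predicted]]
        name = feature_names[feature[node]]
        left_cond = f"{name} <= {threshold[node]}"
        right_cond = f"{name} > {threshold[node]}"
        # Create new lists with +, instead of list.append (which returns None).
        paths = [[left_cond] + sub for sub in recurse(children_left[node])]
        paths += [[right_cond] + sub for sub in recurse(children_right[node])]
        return paths

    return recurse(0)


for path in collect_attack_paths(children_left, children_right, feature,
                                 threshold, value, feature_names):
    print(path)
```

Returning an empty list for "normal" leaves (instead of `None`) means the parent never has to special-case missing results.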
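It returns an enormous result, so large that there isn't enough memory to load it. I'm sure something is wrong with the code, because the paths I need should not add up to anything that big. I also tried the approach in Getting decision path to a node in sklearn, but the output of the sklearn tree structure only confused me more.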
Do you know how to fix this?
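A correct result should be small: a tree has exactly one root-to-leaf path per leaf, so there are fewer paths than `tree_.node_count`, and each path holds at most `tree_.max_depth` conditions. If memory is still a concern, a generator with an explicit stack yields one path at a time instead of building the full list. This is a minimal sketch under the same assumptions as before (made-up arrays in the `tree_` layout, class 0 treated as "normal" via a hypothetical `normal_class` parameter, majority class used at each leaf); `iter_attack_paths` is a hypothetical name.

```python
# Made-up arrays in the same layout as sklearn's estimator.tree_.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 2.0, -2.0, -2.0]
value = [[[11, 5, 7]], [[10, 0, 0]], [[1, 5, 7]], [[0, 5, 0]], [[1, 0, 7]]]
feature_names = ["f0", "f1"]


def iter_attack_paths(children_left, children_right, feature, threshold,
                      value, feature_names, normal_class=0):
    """Yield non-normal root-to-leaf paths one at a time (hypothetical helper).

    Uses an explicit stack instead of recursion; only pending branches and
    the path currently being yielded are held in memory.
    """
    stack = [(0, [])]  # (node id, conditions collected so far)
    while stack:
        node, conds = stack.pop()
        if children_left[node] == -1:  # leaf
            counts = value[node][0]
            predicted = max(range(len(counts)), key=lambda i: counts[i])
            if predicted != normal_class:
                yield conds + [predicted]
            continue
        name = feature_names[feature[node]]
        t = threshold[node]
        # Push the right child first so the left subtree is explored first.
        stack.append((children_right[node], conds + [f"{name} > {t}"]))
        stack.append((children_left[node], conds + [f"{name} <= {t}"]))


for path in iter_attack_paths(children_left, children_right, feature,
                              threshold, value, feature_names):
    print(path)
```

With a fitted random forest (called `forest` here as a placeholder), you could run this once per tree in `forest.estimators_`, passing each tree's `tree_` arrays, and write paths out as they are yielded.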