Having trained an MLlib DecisionTree model (http://spark.apache.org/docs/latest/mllib-decision-tree.html), how can I compute node statistics such as the support (the number of samples that reach a given subtree) and the number of samples of each label that reach that subtree?
If it is easier, I would also be happy to use any tool other than Spark that can take the debug string and compute these statistics. An example debug string:
DecisionTreeModel classifier of depth 20 with 20031 nodes
  If (feature 0 <= -35.0)
   If (feature 24 <= 176.0)
    If (feature 0 <= -200.0)
     If (feature 29 <= 109.0)
      If (feature 6 <= -156.0)
       If (feature 9 <= 0.0)
        If (feature 20 <= -116.0)
         If (feature 16 <= 203.0)
          If (feature 11 <= 163.0)
           If (feature 5 <= 384.0)
            If (feature 15 <= 325.0)
             If (feature 13 <= -248.0)
              If (feature 20 <= -146.0)
               Predict: 0.0
              Else (feature 20 > -146.0)
               If (feature 19 <= -58.0)
                Predict: 6.0
               Else (feature 19 > -58.0)
                Predict: 0.0
             Else (feature 13 > -248.0)
              If (feature 9 <= -26.0)
               Predict: 0.0
              Else (feature 9 > -26.0)
               If (feature 10 <= 218.0)
                ...
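For what it's worth, one Spark-independent workaround for exactly these counts is to parse the debug string back into a tree and then replay the labeled training data through it. The sketch below is my own illustration, not a Spark API: parse_subtree and add_sample are hypothetical names, and the regexes assume the debug string follows exactly the one-node-per-line If/Else/Predict layout shown above.

import re
from collections import Counter

SPLIT_RE = re.compile(r"If \(feature (\d+) <= (-?[\d.]+)\)")
LEAF_RE = re.compile(r"Predict: (-?[\d.]+)")

def parse_subtree(lines, pos):
    """Parse the subtree starting at lines[pos]; return (node, first unconsumed index)."""
    line = lines[pos].strip()
    split = SPLIT_RE.match(line)
    if split:
        node = {"feature": int(split.group(1)),
                "threshold": float(split.group(2)),
                "support": 0, "labels": Counter()}
        node["left"], pos = parse_subtree(lines, pos + 1)   # branch where the test holds
        node["right"], pos = parse_subtree(lines, pos + 1)  # pos now sits on the "Else (...)" line; skip it
        return node, pos
    leaf = LEAF_RE.match(line)
    node = {"predict": float(leaf.group(1)), "support": 0, "labels": Counter()}
    return node, pos + 1

def add_sample(node, features, label):
    """Route one labeled sample down the tree, updating every node on its path."""
    node["support"] += 1
    node["labels"][label] += 1
    if "feature" in node:  # internal node: descend into the matching branch
        child = "left" if features[node["feature"]] <= node["threshold"] else "right"
        add_sample(node[child], features, label)

Assuming model is a pyspark.mllib DecisionTreeModel and training_data is an RDD of LabeledPoint, it could be driven like this (toLocalIterator streams the partitions so the data never has to fit in memory at once):

lines = model.toDebugString().split("\n")
root, _ = parse_subtree(lines, 1)  # line 0 is the "DecisionTreeModel ..." header
for point in training_data.toLocalIterator():
    add_sample(root, point.features, point.label)
print(root["support"], dict(root["labels"]))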
I am using MLlib because I need out-of-core learning: the data does not fit in memory. If you have a better alternative to MLlib, I would be glad to try it.
Answer 0 (score: 0)
I used sklearn as the algorithm to build my model and integrated it with a SparkContext to produce output like this:
if ( device_type_id <= 1 )
  39 Clicks - 0.61%
  2135 Conversions - 33.32%
else ( device_type_id > 1 )
  if ( country_id <= 216 )
    1097 Clicks - 17.12%
  else ( country_id > 216 )
    if ( browser_id <= 2 )
      296 Clicks - 4.62%
    else ( browser_id > 2 )
      if ( browser_id <= 4 )
        if ( browser_id <= 3 )
          if ( operating_system_id <= 2 )
            262 Clicks - 4.09%
Here is the code I use to print such a tree:
import numpy as np

def get_code(count_df, tree, feature_names, target_names, spacer_base="  "):
    # Flatten the fitted sklearn tree into parallel arrays.
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value
    temp_list = []
    res_count = count_df  # total sample count, used for the percentage column

    def recurse(res_count, temp_list, left, right, threshold, features, node, depth):
        spacer = spacer_base * depth
        if threshold[node] != -2:  # -2 (TREE_UNDEFINED) marks a leaf
            # floor() displays the split as an integer threshold with the same
            # "<=" semantics, assuming integer-valued features.
            int_threshold = str(int(np.floor(threshold[node])))
            temp_list.append(spacer + "if ( " + features[node] + " <= " + int_threshold + " )")
            if left[node] != -1:
                recurse(res_count, temp_list, left, right, threshold,
                        features, left[node], depth + 1)
            temp_list.append(spacer + "else ( " + features[node] + " > " + int_threshold + " )")
            if right[node] != -1:
                recurse(res_count, temp_list, left, right, threshold,
                        features, right[node], depth + 1)
        else:
            # Leaf: report each class's sample count and its share of all samples.
            target = value[node]
            for i, v in zip(np.nonzero(target)[1], target[np.nonzero(target)]):
                target_name = target_names[i]
                target_count = int(v)
                percent = round(float(target_count) / res_count * 100, 2)
                temp_list.append(spacer + str(target_count) + " " + str(target_name)
                                 + " - " + str(percent) + "%")

    recurse(res_count, temp_list, left, right, threshold, features, 0, 0)
    return temp_list
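For reference, a hypothetical call site might look like the following; the DataFrame, column names, and labels are made up for illustration, since get_code only needs a fitted sklearn tree, the total row count, and the class names (clf.classes_ matches the column order of tree_.value):

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Toy stand-in data; real column names and labels come from your own dataset.
df = pd.DataFrame({
    "device_type_id": [1, 2, 2, 1, 2, 1],
    "country_id": [216, 100, 300, 50, 216, 10],
    "browser_id": [2, 3, 4, 1, 2, 3],
    "outcome": ["Clicks", "Conversions", "Clicks", "Clicks", "Conversions", "Clicks"],
})
feature_names = ["device_type_id", "country_id", "browser_id"]
clf = DecisionTreeClassifier().fit(df[feature_names], df["outcome"])

# Percentages in the leaves are taken against the total sample count, len(df).
for line in get_code(len(df), clf, feature_names, clf.classes_):
    print(line)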
Otherwise, see the answer I provided in my post here; it is written in Scala, however, and changes the way Spark generates decision trees.