Spark mllib DecisionTree

Asked: 2016-06-07 09:40:58

Tags: apache-spark pyspark decision-tree apache-spark-mllib

Having trained an MLlib DecisionTree model (http://spark.apache.org/docs/latest/mllib-decision-tree.html), how can I compute node statistics such as the support (the number of samples that match a given subtree) and the number of samples of each label that match that subtree?
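For context, whole-model statistics are easy to get. The minimal sketch below (assuming `model` is the trained `pyspark.mllib.tree.DecisionTreeModel` and `data` the training RDD of `LabeledPoint`) counts (label, prediction) pairs, but gives nothing at the node level:

# Minimal sketch of what is easy today: whole-model (label, prediction)
# counts, assuming `model` is a trained pyspark.mllib.tree.DecisionTreeModel
# and `data` an RDD of LabeledPoint. This yields confusion-matrix style
# counts, not the per-node support I am after.
predictions = model.predict(data.map(lambda lp: lp.features))
label_and_pred = data.map(lambda lp: lp.label).zip(predictions)
counts = label_and_pred.countByValue()  # {(true_label, predicted_label): n}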

If it is easier, I would also be happy to use any tool other than Spark that can take the debug string and compute these statistics. An example debug string:

DecisionTreeModel classifier of depth 20 with 20031 nodes
  If (feature 0 <= -35.0)
   If (feature 24 <= 176.0)
    If (feature 0 <= -200.0)
     If (feature 29 <= 109.0)
      If (feature 6 <= -156.0)
       If (feature 9 <= 0.0)
        If (feature 20 <= -116.0)
         If (feature 16 <= 203.0)
          If (feature 11 <= 163.0)
           If (feature 5 <= 384.0)
            If (feature 15 <= 325.0)
             If (feature 13 <= -248.0)
              If (feature 20 <= -146.0)
               Predict: 0.0
              Else (feature 20 > -146.0)
               If (feature 19 <= -58.0)
                Predict: 6.0
               Else (feature 19 > -58.0)
                Predict: 0.0
             Else (feature 13 > -248.0)
              If (feature 9 <= -26.0)
               Predict: 0.0
              Else (feature 9 > -26.0)
               If (feature 10 <= 218.0)
...
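For concreteness, here is a rough sketch of the kind of computation I mean: parse the debug string and replay it over the data. This is hand-rolled, not a Spark API; it assumes `model` and `data` as above, and that every split is a continuous binary split like the ones shown (categorical splits print as `feature f in {...}` and would need another pattern):

import re

_SPLIT = re.compile(r'(?:If|Else) \(feature (\d+) (<=|>) (.+)\)')
_LEAF = re.compile(r'Predict: (.+)')

def leaf_paths(debug_string):
    """Yield (conditions, prediction) for each leaf, where a condition is a
    (feature_index, op, threshold) triple along the root-to-leaf path."""
    stack = []  # (indent, condition) pairs on the current path
    for line in debug_string.splitlines()[1:]:  # skip the header line
        indent = len(line) - len(line.lstrip())
        while stack and stack[-1][0] >= indent:
            stack.pop()  # left a subtree: drop its conditions
        m = _SPLIT.search(line)
        if m:
            feat, op, thr = m.groups()
            stack.append((indent, (int(feat), op, float(thr))))
        else:
            m = _LEAF.search(line)
            if m:
                yield [cond for _, cond in stack], float(m.group(1))

def matches(features, conditions):
    return all(features[f] <= t if op == '<=' else features[f] > t
               for f, op, t in conditions)

paths = list(leaf_paths(model.toDebugString()))

# {(leaf_index, label): count}; a subtree's support is the sum over the
# leaves beneath it, i.e. over the paths sharing its condition prefix.
stats = (data.flatMap(lambda lp: [((i, lp.label), 1)
                                  for i, (conds, _) in enumerate(paths)
                                  if matches(lp.features, conds)])
             .reduceByKey(lambda a, b: a + b)
             .collectAsMap())

This makes a full pass over the data per model, which is exactly the sort of thing I am hoping a library already does.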

I am using MLlib for the out-of-core learning; I need it because the data does not fit in memory. If you have a better option than MLlib, I would be happy to try it.

1 Answer:

Answer 0 (score: 0)

I used sklearn as the algorithm to build my model, and integrated it with the Spark context to produce output like this:

if ( device_type_id <= 1 )
    39 Clicks - 0.61%
    2135 Conversions - 33.32% 
else ( device_type_id > 1 )
    if ( country_id <= 216 )
        1097 Clicks - 17.12%
    else ( country_id > 216 )
        if ( browser_id <= 2 )
            296 Clicks - 4.62%
        else ( browser_id > 2 )
            if ( browser_id <= 4 )
                if ( browser_id <= 3 )
                    if ( operating_system_id <= 2 )
                        262 Clicks - 4.09%
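The Spark side of the integration is roughly this: sample the RDD down to something that fits in driver memory, then fit the sklearn tree on the driver. A minimal sketch, with illustrative names (`rdd` is assumed to be an RDD of `LabeledPoint`):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Illustrative sketch: sklearn runs on the driver, so first sample the
# distributed data down to something that fits in memory.
sample = rdd.sample(withReplacement=False, fraction=0.1, seed=42).collect()
X = np.array([lp.features.toArray() for lp in sample])
y = np.array([lp.label for lp in sample])

clf = DecisionTreeClassifier(max_depth=20)
clf.fit(X, y)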

Here is the code I use to display such a tree:

import numpy as np

def get_code(count_df, tree, feature_names, target_names, spacer_base="    "):
    """Render a fitted sklearn decision tree as indented if/else text.

    count_df is the total number of samples; it is only used as the
    denominator for the percentage printed next to each leaf count.
    """
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value
    temp_list = []
    res_count = count_df

    def recurse(res_count, temp_list, left, right, threshold, features, node, depth):
        spacer = spacer_base * depth
        if threshold[node] != -2:  # -2 marks a leaf in sklearn's tree_ arrays
            # Print the integer part of the threshold so that a split such as
            # "device_type_id <= 1.5" renders as "device_type_id <= 1".
            thresh = str(int(threshold[node]))
            temp_list.append(spacer + "if ( " + features[node] + " <= " + thresh + " )")
            if left[node] != -1:
                recurse(res_count, temp_list, left, right, threshold, features,
                        left[node], depth + 1)
            temp_list.append(spacer + "else ( " + features[node] + " > " + thresh + " )")
            if right[node] != -1:
                recurse(res_count, temp_list, left, right, threshold, features,
                        right[node], depth + 1)
        else:
            # Leaf: report each class's sample count and its share of the
            # full data set.
            target = value[node]
            for i, v in zip(np.nonzero(target)[1], target[np.nonzero(target)]):
                target_name = target_names[i]
                target_count = int(v)
                pct = round(100.0 * target_count / res_count, 2)
                temp_list.append(spacer + str(target_count) + " " + str(target_name)
                                 + " - " + str(pct) + "%")

    recurse(res_count, temp_list, left, right, threshold, features, 0, 0)
    return temp_list
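Called on a fitted tree it looks like this (illustrative names, continuing the sketch above; the first argument is the total sample count used for the percentages):

# Illustrative usage with the classifier fitted above; feature_names must
# have one entry per column of X, in training order.
feature_names = ["device_type_id", "country_id", "browser_id", "operating_system_id"]
target_names = ["Clicks", "Conversions"]
for line in get_code(len(y), clf, feature_names, target_names):
    print(line)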

Otherwise, see the answer I provided in my post here, but it is written in Scala and changes the way Spark generates the decision tree.