帮助编写python中的决策树

时间:2015-08-18 03:36:43

标签: machine-learning python

我不确定这是否是发布此内容的正确位置,但我一直试图编写一个简单的决策树类一段时间,并在各个点迷失。

具体来说,我不确定哪种数据结构会代表使用(feature,value)作为节点的递归树。

class DecisionTree():

def entropy(self, data):

    # if there's nothing in this region, entropy is 1
    if len(data) <= 1:
        return 1

    target_col = data.ix[:,-1]
    size = float(len(target_col))
    classes = Counter(target_col)

    # if there's only one class, entropy is 1
    if len(classes) == 1:
        return 1

    else:
        probs = [i / size for i in classes.values()]
        entropy = np.sum([-probs[i]*np.log(probs[i]) for i in range(len(probs))])

    return entropy

def what_to_split_on(self, data):

    split_feature = -1
    best_entropy = 0.0
    base_entropy = self.entropy(data)

    for f, feature in enumerate(data.T):

        unique_vals = list(set(feature))
        for val in unique_vals:

            left, right = self.split(f, val)
            prop_left = float(len(left)) / (len(left) + len(right))
            prop_right = 1 - prop_left

            e_1 = prop_left * self.entropy(left)
            e_2 = prop_right * self.entropy(right)

            entropy_change = base_entropy - e_1 - e_2
            if entropy_change > best_entropy:

                best_entropy = entropy_change
                split_feature = f; split_val = val

    if split_feature != -1:
        return split_feature, split_val

def split(self, data, f, val):

    left = np.array([row for row in data if row[f] == val])
    right = np.array([row for row in data if row[f] != val])

    return left, right

def create_tree(self, data):

    if self.entropy(data) == 1:
        return

    feature, value = self.what_to_split_on(data)

    dt = Tree(feature, value)

    left_child = np.array([row for row in data if row[feature] == value])
    right_child = np.array([row for row in data if row[feature] == value])

    feature, value = self.what_to_split_on(left_child)
    sub_left = create_tree(left_child)
    dt.insert_left(sub_left)

    feature, value = self.what_to_split_on(right_child)
    sub_right = create_tree(right_child)
    dt.insert_right(sub_right)

    return dt

0 个答案:

没有答案