无法从头开始实现描述树的分支(LONG)

时间:2017-04-03 17:21:10

标签: python recursion machine-learning decision-tree

我目前正在使用Python从头开始实现Decision Tree算法。 我在实现树的分支时遇到了麻烦。在当前的实现中,我没有使用Depth参数。

发生的事情是,分支结束得太快(如果我使用标志来阻止无限递归),或者如果我删除了标志,我会遇到无限递归。如果我在主循环或递归循环中,我也无法理解。

我的数据非常简单:

d = {'one' : [1., 2., 3., 4.],
 'two' : [4., 3., 2., 1.]}

df = pd.DataFrame(d)

df['three'] = (0,0,1,1)
df = np.array(df)

导致输出:

array([[ 1.,  4.,  0.],
       [ 2.,  3.,  0.],
       [ 3.,  2.,  1.],
       [ 4.,  1.,  1.]])

我将使用gini_index进行拆分。这个功能不是解决我问题的必要条件,所以我将把它放在这个问题的最后,以帮助重现。

我正在使用字典对象y,它将在分支扩展时继续包含嵌套字典。

                              y
                          /         \
                   y['left']               y['right']    
                  /       \                         \
   y['left']['left']   y['left']['right']              y['right']  ['right']

接下来,我将分解创建树的功能,我遇到了一些问题。

def create_tree2(node, flag ):   #node is a dictionary containing the root, which will contain nested dictionaries as this function recursively calls itself.

    left, right =node['Groups']  # ['Groups'] is a key contains that contains two groups which will be used for the next split; I'm assigning them to left and right here
    left,right = np.array(left), np.array(right)  #just converting them to array because my other functions rely on the data to be in array format. 


    print ('left_group', left)    #these are for debugging purposes. 
    print('right_group', right)

if flag == True and (right.size ==0 or left.size ==0):   
    node['left'] = left
    node['right'] = right
    flag = False
    return 

#This above portion is to prevent infinite loops.

关于无限递归,发生了什么,如果我有两行数据,而不是将两行分成两个不同的节点, 我得到一个没有行的节点,另一个节点有两行。

如果一个节点中的数据行少于两行,则我的循环通常会停止。 所以空节点会终止, 但是具有两行数据的节点将再次拆分为空节点和两行填充节点。这个过程将永远持续下去。  所以我尝试使用一个标志来防止这种无限循环。 旗帜的唯一问题是,它似乎提前一步激活。, 它不会检查拆分是否会导致两个节点或无限循环。例如:

A split leads to 
left = []
right =     [ [ 3.,  2.,  1.],
       [ 4.,  1.,  1.]])]

now instead of checking if the right can split further
 (left =[3,2,1] , right =  [ 4.,  1.,  1.]), 

旗帜在上面的步骤停止,太早了一步。

if len(left) < 2:
    node['left'] =left
    return

#Here I'm ending the node, if the len is less than 2 rows of data. 




else:

    node['left'] = check_split(left)
    print('after left split', node['left']['Groups'])# for debugging purposes
    create_tree2(node['left'], True)


#This is splitting the data and then recursively calling the create_tree2 function
#given that len of the group is NOT less than two. 
#And the flag gets activated to  prevent infinite looping. 
#Notice that node['left'] is being used as the node parameter in the recursion function.





if len(right) <2:
    node['right'] = right
    return
else:
    node['right'] = check_split(right)
    print('right_check_split')
    create_tree(node['right'],False)



#doing the same thing with the right side. 

这里唯一的问题(我假设)是如果左侧先递归调用自己, 然后节点参数将更改为节点['left']字典 以及左右局部变量 使用左侧分支信息进行更新。

让我们看一下输出
以下是代码在被调用后的外观:

#first split

left_group [[ 1.  4.  0.]
 [ 2.  3.  0.]]

right_group [[ 3.  2.  1.]
[ 4.  1.  1.]]






# first the left_group calls itself recursively producing an additional split
resulting in
a new left group that  is empty, and a right_group has two rows

left_group []

right_group [[ 1.  4.  0.]
[ 2.  3.  0.]]


# now  the `if` flag statement gets called
 `if flag == True and (right.size ==0 or left.size ==0):   
        node['left'] = left
        node['right'] = right
        flag = False
        return `


    #ideally I want to do one more  split on the right group,
 to see if right group   would split further but didn't know how to implement that properly. I'm assuming I would need some sort of counter? 


#Next it jumps to the right main branch correctly. 
not sure how as `right` was updated after    the left's recursive function


right_check_split

left_group []
right_group [[ 3.  2.  1.]
[ 4.  1.  1.]]


This also activates the flag which stops the iteration. Ideally I would like this to go at least one more round to check if the right group [3,2,1] and [4,1,1] would split into two branches. Not sure how to do that?  

我很困惑的另一件事是字典为什么能够在正确的主节点中启动,而不是左边的嵌套字典。

回想一下,递归首先发生在主左分支

create_tree2(node['left'] , True), 

这应该更新左边和右边的值,当我们点击这部分函数时它会继续:

if len(right) <2:
    node['right'] = right
    return
else:
    node['right'] = check_split(right)  #This right value would have been updated on?
    print('right_check_split')
    create_tree(node['right'],False)

所以我担心正确的值会改为  [[ 1. 4. 0.] [ 2. 3. 0.]]但它记住了根节点的原始正确值,即

right_group [[ 3. 2. 1.] [ 4. 1. 1.]].

所以我的问题是

1)如何正确实现标志以检查以确保在启动if flag loop

之前确实存在无限递归

2)尽管递归函数使用左分支值更新参数,但我的函数能够使用先前正确的值(这是我想要的)并且能够在适当的位置正确创建新的嵌套字典。

如果需要,可以填写完整的代码

import numpy as np
import pandas as pd


d = {'one' : [1., 2., 3., 4.],
 'two' : [4., 3., 2., 1.]}

df = pd.DataFrame(d)

df['three'] = (0,0,1,1)
df = np.array(df)




def split_method(data, index, value):
    left, right  = list(), list()
    for row in data:
    #for i in range((data.shape[-1] -1)):
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)

return left, right





def gini(data,groups ):
    data_size = len(data)


    gini_index = 0

    for group in groups:
        group_size = len(group)
        multiplier = float(group_size/data_size)
        prob =1
        if group_size == 0:
                continue
        print('multiplier', multiplier)
        for value in set(data[:,-1]):
            prob*= [row[-1] for row in group].count(value)/group_size
            print ('prob', prob)
        gini_index +=  (multiplier * prob)

    return gini_index




def check_split(data):
    main_score = 999
    gini_index = 999
    gini_value = 999
    print('data', data)
    for index in range(len(data[0])-1):
        for rows in data:
            value = rows[index]
            groups =split_method(data, index, value)
            gini_score =gini(data,groups)

            if gini_score < main_score:
                main_score = gini_score
                gini_index, gini_value, gini_groups = index, value,np.array(groups)


    return {'Index': gini_index, 'Value': gini_value, 'Groups': gini_groups}



def create_tree2(node, flag ):

    left, right =node['Groups']
    left,right = np.array(left), np.array(right)
    print ('left_group', left)
    print('right_group', right)

    if flag == True and (right.size ==0 or left.size ==0):
        node['left'] = left
        node['right'] = right
        flag = False
        return 

    if len(left) < 2:
        node['left'] =left
        return

    else:

        node['left'] = check_split(left)
        print('after left split', node['left']['Groups'])
        create_tree2(node['left'],flag = True)

    if len(right) <2:
        node['right'] = right
        return
    else:
        node['right'] = check_split(right)
        print('right_check_split')
        create_tree2(node['right'],flag =True)




    return    node



root = check_split(df)   # this creates the root dictionary, (first dictionary)
y = create_tree2(root, False)

1 个答案:

答案 0 :(得分:1)

我对您的功能进行了以下更改:

def create_tree2(node, flag=False):

    left, right =node['Groups']
    left, right = np.array(left), np.array(right)
    print('left_group', left)
    print('right_group', right)

    if flag == True and (right.size ==0 or left.size ==0):
        node['left'] = left
        node['right'] = right
        flag = False
        return

    if len(left) < 2:
        node['left'] = left
        flag = True
        print('too-small left. flag=True')
    else:
        node['left'] = check_split(left)
        print('after left split', node['left']['Groups'])
        create_tree2(node['left'],flag)

    if len(right) < 2:
        node['right'] = right
        print('too-small right. flag=True')
        flag = True
    else:
        node['right'] = check_split(right)
        print('after right split', node['right']['Groups'])
        create_tree2(node['right'], flag)

    return    node


d = {'one' : [1., 2., 3., 4.],
 'two' : [4., 3., 2., 1.]}

df = pd.DataFrame(d)

df['three'] = (0,0,1,1)
df = np.array(df)

root = check_split(df)
y = create_tree2(root)

基本上,我使用len<2检查将标志设置为True,然后允许右侧递归。我仍然认为这是对的,因为len == 1可能会发生一些事情。但是没有无限的递归。

我得到了这个输出:

left_group [[ 1.  4.  0.]
 [ 2.  3.  0.]]
right_group [[ 3.  2.  1.]
 [ 4.  1.  1.]]
after left split [array([], shape=(0, 3), dtype=float64)
 array([[ 1.,  4.,  0.],
       [ 2.,  3.,  0.]])]
left_group []
right_group [[ 1.  4.  0.]
 [ 2.  3.  0.]]
too-small left. flag=True
after right split [array([], shape=(0, 3), dtype=float64)
 array([[ 1.,  4.,  0.],
       [ 2.,  3.,  0.]])]
left_group []
right_group [[ 1.  4.  0.]
 [ 2.  3.  0.]]
after right split [array([], shape=(0, 3), dtype=float64)
 array([[ 3.,  2.,  1.],
       [ 4.,  1.,  1.]])]
left_group []
right_group [[ 3.  2.  1.]
 [ 4.  1.  1.]]
too-small left. flag=True
after right split [array([], shape=(0, 3), dtype=float64)
 array([[ 3.,  2.,  1.],
       [ 4.,  1.,  1.]])]
left_group []
right_group [[ 3.  2.  1.]
 [ 4.  1.  1.]]
Y= {'Groups': array([[[ 1.,  4.,  0.],
        [ 2.,  3.,  0.]],

       [[ 3.,  2.,  1.],
        [ 4.,  1.,  1.]]]), 'Index': 0, 'right': {'Groups': array([array([], shape=(0, 3), dtype=float64),
       array([[ 3.,  2.,  1.],
       [ 4.,  1.,  1.]])], dtype=object), 'Index': 0, 'right': {'Groups': array([array([], shape=(0, 3), dtype=float64),
       array([[ 3.,  2.,  1.],
       [ 4.,  1.,  1.]])], dtype=object), 'Index': 0, 'right': array([[ 3.,  2.,  1.],
       [ 4.,  1.,  1.]]), 'Value': 3.0, 'left': array([], shape=(0, 3), dtype=float64)}, 'Value': 3.0, 'left': array([], shape=(0, 3), dtype=float64)}, 'Value': 3.0, 'left': {'Groups': array([array([], shape=(0, 3), dtype=float64),
       array([[ 1.,  4.,  0.],
       [ 2.,  3.,  0.]])], dtype=object), 'Index': 0, 'right': {'Groups': array([array([], shape=(0, 3), dtype=float64),
       array([[ 1.,  4.,  0.],
       [ 2.,  3.,  0.]])], dtype=object), 'Index': 0, 'right': array([[ 1.,  4.,  0.],
       [ 2.,  3.,  0.]]), 'Value': 1.0, 'left': array([], shape=(0, 3), dtype=float64)}, 'Value': 1.0, 'left': array([], shape=(0, 3), dtype=float64)}}

另外,我认为你可以通过最后检查一个节点的左边或右面是否为空来对其进行优化,将对面的节点向上拉一个。类似的东西:

if node['left'] is empty:
    kid = node['right']
    node.clear()
    for k,v in kid.items():
        node[k]=v
elif node['right'] is empty:
    same basic thing, with left kid

检查空是一招,因为有时它是一个词典,有时候不是。

最后,您似乎并未存储实际的拆分信息。不是决策树的重点 - 知道要比较的因素是什么?难道你不记录每个节点的列和值吗?