在ID3决策树中选择具有属性数值的最佳节点

时间:2014-03-09 02:17:49

标签: python machine-learning decision-tree id3 arff

我有以下代码。当所有属性都是离散(标称)取值时,用它来选择最佳属性可以正常工作。但是,当某些属性(例如年龄 age)取数值(连续值)时,我不确定应该如何修改代码。

from   arff_v2  import  *

import math

def entropy(data, target_attr):
    """
    Calculates the entropy of the given data set for the target attribute.

    Params:
        data: sequence of records; each record is indexed with target_attr
        target_attr: index (or key) of the field holding the class label
    Returns:
        The Shannon entropy, in bits, of the distribution of the values
        found in record[target_attr]. An empty data set yields 0.
    """
    # Calculate the frequency of each of the values in the target attr.
    # Counts are kept as floats so the divisions below are true divisions
    # even without `from __future__ import division`.
    frequencies = {}
    for record in data:
        label = record[target_attr]
        frequencies[label] = frequencies.get(label, 0.0) + 1.0

    # Calculate the entropy of the data for the target attribute
    total = float(len(data))
    result = 0.0
    for count in frequencies.values():
        probability = count / total
        result -= probability * math.log(probability, 2)

    return result

# Parse the training set into its relation name, comment lines, attribute
# descriptions and data rows. This runs as soon as the module is executed.
relation, comments, attrs, data = readArff('heart_train.arff')
#print  attrs  #for self-testing
#print len(attrs)  #for self-testing
# Entropy of the whole data set with respect to column 13 -- presumably the
# 'class' attribute (the 14th column of the ARFF listing); confirm against the file.
ep=entropy(data,13)

#neg_pos={}  #for self-testing
#for record in data:  #for self-testing
#    print record[13]  #for self-testing
#    print neg_pos.has_key(record[13])  #for self-testing
#print ep
def mutual_information(data, attr, target_attr):
    """
    Calculates the information gain (reduction in entropy) that would
    result by splitting the data on the chosen attribute (attr).

    Params:
        data: sequence of records; each record is indexed with attr and target_attr
        attr: index of the attribute to split on
        target_attr: index of the attribute holding the class label
    Returns:
        entropy(data, target_attr) minus the weighted sum of the entropies
        of the subsets that share one value of attr.

    Every distinct value of attr forms its own subset, so a numeric
    attribute is treated as if each number were a separate category.
    """
    # Calculate the frequency of each value of the split attribute and, in
    # the same pass, collect the records that carry that value so the data
    # does not have to be re-scanned once per distinct value.
    value_counts = {}
    subsets = {}
    for record in data:
        value = record[attr]
        value_counts[value] = value_counts.get(value, 0.0) + 1.0
        subsets.setdefault(value, []).append(record)

    # Debug output kept from the original: the value -> count table.
    print(value_counts)

    # Calculate the sum of the entropy for each subset of records weighted
    # by their probability of occurring in the training set.
    total = sum(value_counts.values())
    subcol_entropy = 0
    for value, count in value_counts.items():
        subcol_entropy += (count / total) * entropy(subsets[value], target_attr)

    return entropy(data, target_attr) - subcol_entropy

# Information gain of splitting on column 5 with respect to target column 13
# (presumably 'fbs' versus 'class' in the ARFF listing -- confirm).
mi=mutual_information(data,5,13)

#print "difference"
#print data[:]



def attr_select(data, attrs, target_attr):
    """
    return the attribute with highest information gain

    Params:
        data: sequence of records
        attrs: attribute descriptions; only their number is used, the
            attributes themselves are addressed by column index
        target_attr: index of the class attribute, which is never selected
    Returns:
        The column index of the attribute with the highest information
        gain, or None when there is no attribute other than the target.
        On a tie the attribute with the highest index wins.
    """
    best_mutual_information = None
    best_attr = None
    for index in range(len(attrs)):
        # The target column is not a candidate, so its gain is not computed.
        if index == target_attr:
            continue
        gain = mutual_information(data, index, target_attr)
        # Starting from None rather than 0.0 keeps a candidate whose gain
        # comes out marginally negative through floating-point rounding.
        if best_mutual_information is None or gain >= best_mutual_information:
            best_mutual_information = gain
            best_attr = index
    return best_attr

# Column index of the attribute with the highest information gain for
# target column 13, printed for inspection.
ch=attr_select(data,attrs,13)
print ch
def retrieve_examples(data, attr, value):
    """
    Return the records of data whose field attr equals value.

    Params:
        data: sequence of records (may be empty or None)
        attr: index of the field to compare
        value: the value a record must have in that field to be kept
    Returns:
        A new list with the matching records, in the order in which they
        appear in data. data itself is left unmodified.
    """
    if not data:
        return []
    # A single pass replaces the former pop-and-recurse scheme, which
    # emptied the caller's list and recursed once per record.
    return [record for record in data if record[attr] == value]


# The function used to call itself under this name, so the name is kept
# as an alias for code that refers to it.
get_examples = retrieve_examples

# Presumably collects the records whose column 1 ('sex' in the ARFF listing)
# equals 'male' -- confirm.
# NOTE(review): binding the result to the name `list` shadows the built-in
# list type for the remainder of this module.
list=get_examples(data,1,'male')
#print list

这是我正在使用的数据。它是.arff格式:

    @relation cleveland-14-heart-disease
    @attribute 'age' real
    @attribute 'sex' { female, male}
    @attribute 'cp' { typ_angina, asympt, non_anginal, atyp_angina}
    @attribute 'trestbps' real
    @attribute 'chol' real
    @attribute 'fbs' { t, f}
    @attribute 'restecg' { left_vent_hyper, normal, st_t_wave_abnormality}
    @attribute 'thalach' real
    @attribute 'exang' { no, yes}
    @attribute 'oldpeak' real
    @attribute 'slope' { up, flat, down}
    @attribute 'ca' real
    @attribute 'thal' { fixed_defect, normal, reversable_defect}
    @attribute 'class' { negative, positive}
    @data
    63,male,typ_angina,145,233,t,left_vent_hyper,150,no,2.3,down,0,fixed_defect, positive
    37,male,non_anginal,130,250,f,normal,187,no,3.5,down,0,normal,negative
    41,female,atyp_angina,130,204,f,left_vent_hyper,172,no,1.4,up,0,normal,negative
    56,male,atyp_angina,120,236,f,normal,178,no,0.8,up,0,normal,negative
    57,female,asympt,120,354,f,normal,163,yes,0.6,up,0,normal,positive
    57,male,asympt,140,192,f,normal,148,no,0.4,flat,0,fixed_defect,negative
    56,female,atyp_angina,140,294,f,left_vent_hyper,153,no,1.3,flat,0,normal,negative
    44,male,atyp_angina,120,263,f,normal,173,no,0,up,0,reversable_defect,negative
    52,male,non_anginal,172,199,t,normal,162,no,0.5,up,0,reversable_defect,negative
    57,male,non_anginal,150,168,f,normal,174,no,1.6,up,0,normal,negative
    54,male,asympt,140,239,f,normal,160,no,1.2,up,0,normal,negative
    48,female,non_anginal,130,275,f,normal,139,no,0.2,up,0,normal,positive

这也是我在网上找到的一个arff解析器,工作正常。我放在同一个目录中:

from __future__ import division
"""
Operations on WEKA .arff files

Created on 28/09/2010
@author: peter
"""

import sys, re, os, datetime

def getAttributeByName_(attributes, name):
    """ Look up an attribute by name.

    Params:
        attributes: iterable of dicts that each have a 'name' key
        name: the name to search for
    Returns:
        The first member whose 'name' equals <name>, or None if there is none.
    """
    matching = (candidate for candidate in attributes if candidate['name'] == name)
    return next(matching, None)

def showAttributeByName_(attributes, name, title):
    """ Print the attributes member with name <name>, prefixed by <title>.

    The output line has the form '>>> <title> : <attribute>'; the attribute
    is shown as None when no member has that name.
    """
    # The lookup helper is getAttributeByName_ (with trailing underscore);
    # the name without the underscore is not defined.
    found = getAttributeByName_(attributes, name)
    print('%s %s %s %s' % ('>>>', title, ':', found))

def debugAttributes(attributes, title):
    """ Debugging hook that currently does nothing and returns None. """
    # showAttributeByName(attributes, 'Number.of.Successful.Grant', title)
    return None

def writeArff2(filename, comments, relation, attr_keys, attrs, data, make_copies = False):
    """ Write a WEKA .arff file 
    Params:
        filename: name of .arff file
        comments: free text comments 
        relation: name of data set
        attr_keys: gives order of keys in attrs to match columns
        attrs: dict of attribute: all values of attribute
        data: the actual data
        make_copies: when true, also copy the written file to filename + '.txt'
    """
    assert(len(attr_keys) == len(attrs))
    assert(len(data[0]) == len(attrs))
    assert(len(attrs) >= 2)
    # The with-block closes the file even when a write or an assertion
    # fails part-way through.
    with open(filename, 'w') as f:
        f.write('\n')
        f.write('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n')
        f.write('%% %s \n' % os.path.basename(filename))
        f.write('%\n')
        f.write('% Created by ' + os.path.basename(sys.argv[0]) + ' on ' + datetime.date.today().strftime("%A, %d %B %Y") + '\n')
        f.write('% Code at http://bit.ly/read_arff\n')
        f.write('%\n')
        f.write('%% %d instances\n' % len(data))
        f.write('%% %d attributes + 1 class = %d columns\n' % (len(data[0]) - 1, len(data[0])))
        f.write('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n')
        f.write('\n')
        if comments:
            f.write('% Original comments\n')
            for c in comments:
                f.write(c + '\n')
        f.write('@RELATION ' + relation + '\n\n')
        for name in attr_keys:
            vals = attrs[name]
            if type(vals) is str:
                # A plain string is a type name (e.g. 'real') and is written as is.
                attrs_str = vals
            else:
                # Nominal attribute: list the values, leaving out the
                # missing-value marker '?'.
                attrs_str = '{%s}' % ','.join([x for x in vals if not x == '?'])
            f.write('@ATTRIBUTE %-15s %s\n' % (name, attrs_str))
        f.write('\n@DATA\n\n')
        for instance in data:
            # Empty fields are written as the missing-value marker '?'.
            instance = ['?' if x == '' else x for x in instance]
            for i,name in enumerate(attr_keys):
                # NOTE(review): attrs is looked up by name like a dict in this
                # function, so this condition is false for a dict and the
                # membership check below never runs. Presumably attrs[name]
                # was meant; left unchanged so that data accepted so far is
                # not newly rejected -- confirm the intent.
                if type(attrs) is list:
                    assert(instance[i] in attrs[name]+ ['?'])
            f.write(', '.join(instance) + '\n')
            #print ', '.join(instance)

    #print attr_keys[0], attrs[attr_keys[0]]
    #exit()

    if make_copies:
        # Copy .arff files to .arff.txt so they can be viewed from Google docs
        # shutil is imported here because the module's import line does not
        # include it.
        import shutil
        print('%s %s %s' % ('writeArff:', filename + '.txt', '-- duplicate'))
        shutil.copyfile(filename, filename + '.txt')

def writeArff(filename, comments, relation, attrs_list, data, make_copies = False, test = True):
    """ Write a WEKA .arff file and, optionally, read it back to check it
    Params:
        filename: name of .arff file
        comments: free text comments 
        relation: name of data set
        attrs_list: list of dicts of attribute: all values of attribute
        data: the actual data
        make_copies: passed through to writeArff2
        test: when true, re-read the file and assert that relation,
            attributes and data match what was written
    """
    assert(len(attrs_list) > 0)
    assert(len(data) > 0)
    debugAttributes(attrs_list, 'writeArff')

    # Split the list of attribute dicts into the column order and a
    # name -> values mapping, which is the form writeArff2 takes.
    attr_keys = []
    attrs_dict = {}
    for attribute in attrs_list:
        attr_keys.append(attribute['name'])
    for attribute in attrs_list:
        attrs_dict[attribute['name']] = attribute['vals']
    writeArff2(filename, comments, relation, attr_keys, attrs_dict, data, make_copies)

    if not test:
        return

    out_relation, out_comments, out_attrs_list, out_data = readArff(filename)
    if out_attrs_list != attrs_list:
        # Show the differences before the assertion below fails.
        print('%s %s %s %s' % ('len(out_attrs_list) =', len(out_attrs_list), ', len(attrs_list) =', len(attrs_list)))
        if len(out_attrs_list) == len(attrs_list):
            for i, expected in enumerate(attrs_list):
                print('%s %s %s' % ('%3d:' % i, out_attrs_list[i], expected))
    assert(out_relation == relation)
    assert(out_attrs_list == attrs_list)
    assert(out_data == data)

def getRe(pattern, text):
    """ Return the list of all matches of <pattern> in <text>.

    <pattern> may be a pattern string or a compiled pattern; the result has
    the form produced by re.findall.
    """
    compiled = re.compile(pattern)
    return compiled.findall(text)

# Patterns used to pick apart the header and data lines of an .arff file.
# All of them are compiled case-insensitively.

# Captures the whitespace-free token that follows @RELATION and ends the line.
relation_pattern = re.compile(r'@RELATION\s*(\S+)\s*$', re.IGNORECASE)
# Captures the first whitespace-free token after @ATTRIBUTE (the attribute name).
attr_name_pattern = re.compile(r'@ATTRIBUTE\s*(\S+)\s*', re.IGNORECASE)
# Captures the second whitespace-free token after @ATTRIBUTE (the type, e.g. real).
attr_type_pattern = re.compile(r'@ATTRIBUTE\s*\S+\s*(\S+)', re.IGNORECASE)
# Captures the text between '{' and the last '}' (the list of nominal values).
attr_vals_pattern = re.compile(r'\{\s*(.+)\s*\}', re.IGNORECASE)
# Captures one comma-separated field per match; a field in double quotes may
# itself contain commas and doubled quotes.
csv_pattern = re.compile(r'(?:^|,)(\"(?:[^\"]+|\"\")*\"|[^,]*)', re.IGNORECASE)

def readArff(filename):
    """ Read a WEKA .arff file
    Params: 
        filename: name of .arff file
    Returns:
        A tuple (relation, comments, attrs, data), in that order:
        relation: name of data set
        comments: the lines that start with '%'
        attrs: list of dicts {'name': ..., 'vals': ...}; vals is the type
            string for a line without braces, otherwise the list of values
        data: the actual data, one list of stripped field strings per row
    """
    print('readArff(%s)' % filename)

    # The with-block closes the file as soon as the lines have been read.
    with open(filename) as arff_file:
        lines = [l.strip() for l in arff_file]
    lines = [l for l in lines if len(l)]

    comments = [l for l in lines if l[0] == '%']
    lines = [l for l in lines if not l[0] == '%']

    relation = [l for l in lines if '@RELATION' in l.upper()]
    attributes = [l for l in lines if '@ATTRIBUTE' in l.upper()]

    #for i,a in enumerate(attributes[8:12]):
    #    print '%4d' % (8+i), a

    # Everything after the first line that contains @DATA is a data row.
    data = []
    for index, l in enumerate(lines):
        if '@DATA' in l.upper():
            data = lines[index + 1:]
            break

    #print 'relation =', relation
    out_relation = getRe(relation_pattern, relation[0])[0]

    out_attrs = []

    for l in attributes:
        name = getRe(attr_name_pattern, l)[0]
        if not '{' in l:
            # No braces: the token after the name is the type.
            vals = getRe(attr_type_pattern, l)[0].strip()
        else:
            # Braces: the comma-separated values between them.
            vals_string = getRe(attr_vals_pattern, l)[0]
            vals = [x.strip() for x in vals_string.split(',')]
        out_attrs.append({'name':name, 'vals':vals})

    #print 'out_attrs:', out_attrs
    out_data = []
    for l in data:
        out_data.append([x.strip() for x in getRe(csv_pattern, l)])
    # Every row must have exactly one field per declared attribute.
    for d in out_data:
        assert(len(out_attrs) == len(d))

    debugAttributes(out_attrs, 'readArff')

    return (out_relation, comments, out_attrs, out_data)

def testCsv():
    """ Read the .arff file named on the command line and write it back out
    as '<name>.copy<ext>' next to it.

    Prints a usage line and exits when the command line does not consist of
    exactly one argument.
    """
    if len(sys.argv) != 2:
        print("Usage: arff.py <arff-file>")
        sys.exit()

    in_file_name = sys.argv[1]
    split_name = os.path.splitext(in_file_name)
    out_file_name = split_name[0] + '.copy' + split_name[1]

    print('%s %s' % ('Reading', in_file_name))
    print('%s %s' % ('Writing', out_file_name))

    relation, comments, attrs, data = readArff(in_file_name)
    writeArff(out_file_name, comments, relation, attrs, data)

if __name__ == '__main__':
    # Check that the escaped and the raw spelling of the CSV pattern are the
    # same string, then show what the pattern extracts from a sample line.
    line = '1,a,"x,y",q'
    pattern = '(?:^|,)(\\\"(?:[^\\\"]+|\\\"\\\")*\\\"|[^,]*)'
    patter2 = r'(?:^|,)(\"(?:[^\"]+|\"\")*\"|[^,]*)'
    print(pattern)
    print(patter2)
    assert(patter2 == pattern)
    vals = re.findall(pattern, line)
    print(pattern)
    print(line)
    print(vals)

    # Process the file named on the command line.
    testCsv()

1 个答案:

答案 0 :(得分:2)

将年龄划分成若干区间(例如'22-30'、'31-40'等),这样它就可以像性别(男/女)那样作为离散的类别属性来处理。