如何在 C4.5 中生成混淆矩阵?

时间:2015-04-17 09:03:50

标签: java hadoop mapreduce decision-tree c4.5

我正在尝试在 MapReduce 上实现 C4.5 算法。下面的代码目前只能根据给定的训练数据生成规则集;我想知道如何在此基础上生成混淆矩阵。

此类包含 main 方法。

public class DecisionTreec45 extends Configured implements Tool 
{
    public static MySplit currentsplit=new MySplit();

    public static List <MySplit> splitted=new ArrayList<MySplit>();

    public static int current_index=0;

    public static void main(String[] args) throws Exception {

        MyMapper mp=new MyMapper();
        splitted.add(currentsplit);

        int res=0;
    //  double bestGain=0;
    //  boolean stop=true;
    //  boolean outerStop=true;
        int split_index=0;
        double gainratio=0;
        double best_gainratio=0;
        double entropy=0;
        String classLabel=null;
        int total_attributes=mp.no_Attr;
        total_attributes=4;
        int split_size=splitted.size();
        MyGainRatio gainObj;
        MySplit newnode;

        while(split_size>current_index)
        {
            currentsplit = (MySplit) splitted.get(current_index); 
            gainObj = new MyGainRatio();

            res = ToolRunner.run(new Configuration(), new DecisionTreec45(), args);
            System.out.println("Current  NODE INDEX . ::"+current_index);

            int j=0;
            int temp_size;
            gainObj.getcount();
            entropy=gainObj.currNodeEntophy();
            classLabel=gainObj.majorityLabel();
            currentsplit.classLabel=classLabel;

            if(entropy!=0.0 && currentsplit.attr_index.size()!=total_attributes)
            {
                System.out.println("");
                System.out.println("Entropy  NOTT zero   SPLIT INDEX::    "+entropy);

                best_gainratio=0;

                for(j=0;j<total_attributes;j++)     //Finding the gain of each attribute
                {
                    if(currentsplit.attr_index.contains(j))  // Splitting all ready done with this attribute
                    {
                        // System.out.println("Splitting all ready done with  index  "+j);
                    }
                    else
                    {
                        gainratio=gainObj.gainratio(j,entropy);

                        if(gainratio>=best_gainratio)
                        {
                            split_index=j;
                            best_gainratio=gainratio;
                        }
                    }
                }
                String attr_values_split=gainObj.getvalues(split_index);
                StringTokenizer attrs = new StringTokenizer(attr_values_split);
                int number_splits=attrs.countTokens(); //number of splits possible with  attribute selected
                String red="";
                //  int tred=-1;

                System.out.println(" INDEX ::  "+split_index);
                System.out.println(" SPLITTING VALUES  "+attr_values_split);

                for(int splitnumber=1;splitnumber<=number_splits;splitnumber++)
                {
                    temp_size=currentsplit.attr_index.size();
                    newnode=new MySplit(); 
                    for(int y=0;y<temp_size;y++)   // CLONING OBJECT CURRENT NODE
                    {
                        newnode.attr_index.add(currentsplit.attr_index.get(y));
                        newnode.attr_value.add(currentsplit.attr_value.get(y));
                    }
                    red=attrs.nextToken();

                    newnode.attr_index.add(split_index);
                    newnode.attr_value.add(red);
                    splitted.add(newnode);
                }
            }
            else
            {
                System.out.println("");
                String rule="";
                temp_size=currentsplit.attr_index.size();
                for(int val=0;val<temp_size;val++)  
                {
                    rule=rule+" "+currentsplit.attr_index.get(val)+" "+currentsplit.attr_value.get(val);
                }
                rule=rule+" "+currentsplit.classLabel;
                writeRuleToFile(rule);
                if(entropy!=0.0)
                    System.out.println("Enter rule in file:: "+rule);
                else
                    System.out.println("Enter rule in file Entropy zero ::   "+rule);
            }

            split_size=splitted.size();
            System.out.println("TOTAL NODES::    "+split_size);

            current_index++;
        }
        System.out.println("COMPLETE");
        System.exit(res);
    }

    public static void writeRuleToFile(String text) 
    {
        try {  
            BufferedWriter bw = new BufferedWriter(new FileWriter(new File("/home/hduser/C45/rule.txt/"), true));    
            bw.write(text);
            bw.newLine();
            bw.close();
        } catch (Exception e) {
        }
    }

    public int run(String[] args) throws Exception 
    {
        JobConf conf = new JobConf(getConf(),DecisionTreec45.class);
        conf.setJobName("c4.5");

        // the keys are words (strings)
        conf.setOutputKeyClass(Text.class);
        // the values are counts (ints)
        conf.setOutputValueClass(IntWritable.class);
        conf.setMapperClass(MyMapper.class);
        conf.setReducerClass(MyReducer.class);

        //set your input file path below
        FileInputFormat.setInputPaths(conf, "/home/hduser/Id3_hds/playtennis.txt");
        FileOutputFormat.setOutputPath(conf, new Path("/home/hduser/Id3_hds/1/output"+current_index));
        JobClient.runJob(conf);
        return 0;
    }
}

此类用于计算信息增益比(gain ratio)。

public class MyGainRatio
{
    int linenumber=0;
    static String count[][]=new String[10000][4];
    int currnode[]=new int[100];
    String majorityLabel=null;
    public String majorityLabel()
    {
        return majorityLabel;
    }

    //Calculation of entrophy
    public double currNodeEntophy()
    {
        int currentindex=0;
        double entropy=0;
        currentindex=Integer.parseInt(count[0][0]);
        int i=0;
        int covered[]=new int[1000];
        String classLabel=count[0][2];
        int j=0;
        int ind=-1;
        int maxStrength=0;
        System.out.println("Values in node rep to classwise");
        while(currentindex==Integer.parseInt(count[j][0]))
        {
            if(covered[j]==0)
            {
                classLabel=count[j][2];
                ind++;
                i=j;
                while(currentindex==Integer.parseInt(count[i][0]))
                {
                    if(covered[i]==0)
                    {
                        if(classLabel.contentEquals(count[i][2]))
                        {
                            currnode[ind] = currnode[ind]+Integer.parseInt(count[i][3]);
                            covered[i]=1;
                        }
                    }
                    i++;
                    if(i==linenumber)
                        break;
                }
                if(currnode[ind]>maxStrength)
                {
                    maxStrength=currnode[ind];
                    majorityLabel=classLabel;
                }
                System.out.print("    "+classLabel+"    "+currnode[ind]);
            }
            else
            {
                j++;
            }
            if(j==linenumber)
                break;
          }
          entropy=entropy(currnode);
          return entropy;
      }

      public double entropy(int c[])
      {
          double entropy=0;
          int i=0;
          int sum=0;
          double frac;
          while(c[i]!=0)
          {
              sum=sum+c[i];
              i++;
          }
          i=0;
          while(c[i]!=0)
          {
              frac=(double)c[i]/sum;
              entropy=entropy-frac*(Math.log(frac)/Math.log(2));
              i++;
          }

          return entropy;
      }

      public void getcount()
      { 
          DecisionTreec45 id=new DecisionTreec45();
          FileInputStream fstream;
          try {
              fstream = new FileInputStream("/home/hduser/C45/output/intermediate" + id.current_index + ".txt");
              DataInputStream in = new DataInputStream(fstream);
              BufferedReader br = new BufferedReader(new InputStreamReader(in));
              String line;
              //Read File Line By Line
              StringTokenizer itr;
              // System.out.println("READING FROM intermediate  "+id.current_index);

              while ((line = br.readLine()) != null)   {
                  itr= new StringTokenizer(line);
                  count[linenumber][0]=itr.nextToken();
                  count[linenumber][1]=itr.nextToken();
                  count[linenumber][2]=itr.nextToken();
                  count[linenumber][3]=itr.nextToken();
                  int i=linenumber;

                  linenumber++;
              }
              count[linenumber][0]=null;
              count[linenumber][1]=null;
              count[linenumber][2]=null;
              count[linenumber][3]=null;
              in.close();
          } catch (Exception e) {
              // TODO Auto-generated catch block
              e.printStackTrace();
              //Close the input stream
          }
      }



      public double gainratio(int index,double enp)
      {

        //100 is considered as max ClassLabels
          int c[][]=new int[1000][100];
          int sum[]=new int[1000]; //
          String currentatrrval="@3#441get";
          double gainratio=0;
          int j=0;
          int m=-1;  //index for split number 
          int lines=linenumber;
          int totalsum=0;
          for(int i=0;i<lines;i++)
          {
              if(Integer.parseInt(count[i][0])==index)
              {


                  if(count[i][1].contentEquals(currentatrrval))
                  {
                  j++;
                  c[m][j]=Integer.parseInt(count[i][3]);
                  sum[m]=sum[m]+c[m][j];
                  }
                  else
                  {
                      j=0;
                      m++;
                      currentatrrval=count[i][1];
                      c[m][j]=Integer.parseInt(count[i][3]); //(different class) data sets count per m index split
                      sum[m]=c[m][j];
                  }

              }
          }
          int p=0;
          while(sum[p]!=0)
          {
          totalsum=totalsum+sum[p]; //calculating total instance in node
          p++;
          }

          double wtenp=0;
          double splitenp=0;
          double part=0;
          for(int splitnum=0;splitnum<=m;splitnum++)
          {
              part=(double)sum[splitnum]/totalsum;
             wtenp=wtenp+part*entropy(c[splitnum]);
          }
          splitenp=entropy(sum);
          gainratio=(enp-wtenp)/(splitenp);
          return gainratio;

      }


      public String getvalues(int n)
         {   int flag=0;
             String values="";
             String temp="%%%%%!!@";
             for(int z=0;z<1000;z++)
             {
              if(count[z][0]!=null)
              {
            if(n==Integer.parseInt(count[z][0]))
             {
                 flag=1;

                 if(count[z][1].contentEquals(temp))
                 {
               // System.out.println("Equals  COUNT  Index z "+z+"   "+count[z][1]+ "temp  "+temp);
                 }
                 else
                 {

                     values=values+" "+count[z][1];
                     temp=count[z][1];

                 }
             }
            else if(flag==1)
                break;
            }
            else
                break;
             }
             return values;

         }

}

此 Mapper 类检查每个实例是否属于当前节点。对于所有尚未用于划分的属性,它输出属性索引、属性值以及该实例的类标签。

public class MyMapper extends MapReduceBase
implements Mapper<LongWritable, Text, Text, IntWritable> {

private final static IntWritable one = new IntWritable(1);
private Text attValue = new Text();
private Text cLabel = new Text();
private int i;
private String token;
public static int no_Attr;
//public static int splitAttr[];
private int flag=0;



public void map(LongWritable key, Text value,OutputCollector<Text, IntWritable> output,Reporter reporter) throws IOException {

  DecisionTreec45 id=new DecisionTreec45();
  MySplit split=null;
  int size_split=0;
  split=id.currentsplit;

  String line = value.toString();      //changing input instance value to string
  StringTokenizer itr = new StringTokenizer(line);
  int index=0;
  String attr_value=null;
  no_Attr=itr.countTokens()-1;
  String attr[]=new String[no_Attr];
  boolean match=true;
  for(i=0;i<no_Attr;i++)
  {
      attr[i]=itr.nextToken();      //Finding the values of different attributes
  }
  String classLabel=itr.nextToken();
  size_split=split.attr_index.size();
  for(int count=0;count<size_split;count++)
  {
      index=(Integer) split.attr_index.get(count);
      attr_value=(String)split.attr_value.get(count);
     if(attr[index].equals(attr_value))   //may also use attr[index][z][1].contentEquals(attr_value)
     {
         //System.out.println("EQUALS IN MAP  nodes  "+attr[index]+"   inline  "+attr_value);
     }
     else
     {
        // System.out.println("NOT EQUAL IN MAP  nodes  "+attr[index]+"   inline  "+attr_value);
         match=false;
         break;
     }

  }


  //id.attr_count=new int[no_Attr];

  if(match)
  {
      for(int l=0;l<no_Attr;l++)
      {  
          if(split.attr_index.contains(l))
          {

          }
          else
          {
              token=l+" "+attr[l]+" "+classLabel;
              attValue.set(token);
              output.collect(attValue, one);
          }

    }
      if(size_split==no_Attr)
      {
          token=no_Attr+" "+"null"+" "+classLabel;
          attValue.set(token);
          output.collect(attValue, one);
        }
   }
 } 
}

此 Reducer 类统计每种组合(属性索引、属性值和类标签)的出现次数,并输出这些计数。

public class MyReducer extends MapReduceBase implements
        Reducer<Text, IntWritable, Text, IntWritable> {
    public void reduce(Text key, Iterator<IntWritable> values,
            OutputCollector<Text, IntWritable> output, Reporter reporter)
            throws IOException {
        int sum = 0;
        String line = key.toString();
        // StringTokenizer itr = new StringTokenizer(line);
        while (values.hasNext()) {
            sum += values.next().get();
        }
        output.collect(key, new IntWritable(sum));
        writeToFile(key + " " + sum);
        /*
         * int index=Integer.parseInt(itr.nextToken()); String
         * value=itr.nextToken(); String classLabel=itr.nextToken(); int
         * count=sum;
         */
    }

    public static void writeToFile(String text) {
        try 
        {
            DecisionTreec45 id = new DecisionTreec45();
            BufferedWriter bw = new BufferedWriter(new FileWriter(new File(
                    "/home/hduser/C45/output/intermediate" + id.current_index
                            + ".txt"), true));
            bw.write(text);
            bw.newLine();
            bw.close();
        } 
        catch (Exception e) 
        {
        }
    }
}

此类表示决策树中的一个划分节点(记录从根节点到该节点所使用的属性索引及其取值)。

public class MySplit implements Cloneable
{
    public List attr_index;
    public List attr_value;
    double entophy;
    String classLabel;
    MySplit()
    {
         this.attr_index= new ArrayList<Integer>();
         this.attr_value = new ArrayList<String>();
    }
    MySplit(List attr_index,List attr_value)
    {
        this.attr_index=attr_index;
        this.attr_value=attr_value;

    }

    void add(MySplit obj)
    {
        this.add(obj);
    }
}

输入训练数据集如下所示:

sunny hot high weak no

sunny hot high strong no

overcast hot high weak yes

rain mild high weak yes

rain cool normal weak yes

rain cool normal strong no

overcast cool normal strong yes

sunny mild high weak no

sunny cool normal weak yes

rain mild normal weak yes

sunny mild normal strong yes

overcast mild high strong yes

overcast hot normal weak yes

rain mild high strong no

0 个答案:

没有答案