我正在尝试在 MapReduce 上实现 C4.5 算法,这里的代码只根据给定的训练数据生成规则集。
此类包含主要方法。
/**
 * Driver for a MapReduce-based C4.5 decision-tree builder.
 *
 * Each pass of the while-loop in main() runs one MapReduce job that counts,
 * for the tree node currently being expanded, the occurrences of every
 * (attribute, value, class-label) combination.  From those counts the node's
 * entropy and the gain ratio of each unused attribute are computed; the best
 * attribute is used to split the node and the child nodes are queued.  Nodes
 * that are pure (entropy 0) or have no attributes left become leaves and are
 * written out as rules.
 */
public class DecisionTreec45 extends Configured implements Tool
{
    // Node currently being expanded; read by MyMapper to filter instances.
    public static MySplit currentsplit = new MySplit();
    // Work queue of tree nodes, processed breadth-first.
    public static List<MySplit> splitted = new ArrayList<MySplit>();
    // Index into `splitted` of the node being processed; also used to name
    // the per-node intermediate and job output files.
    public static int current_index = 0;

    public static void main(String[] args) throws Exception {
        splitted.add(currentsplit);

        int res = 0;
        int split_index = 0;
        double gainratio = 0;
        double best_gainratio = 0;
        double entropy = 0;
        String classLabel = null;
        // TODO(review): attribute count is hard-coded for the playtennis data
        // set (outlook, temperature, humidity, wind).  MyMapper.no_Attr is
        // only known after the first job has run, so it cannot be used here.
        int total_attributes = 4;
        int split_size = splitted.size();
        MyGainRatio gainObj;
        MySplit newnode;

        // Breadth-first expansion: `splitted` grows while we walk it.
        while (split_size > current_index)
        {
            currentsplit = splitted.get(current_index);
            gainObj = new MyGainRatio();
            // Run one counting job; its reducer appends the per-node counts
            // to intermediate<current_index>.txt, which getcount() loads.
            res = ToolRunner.run(new Configuration(), new DecisionTreec45(), args);
            System.out.println("Current NODE INDEX . ::" + current_index);

            int temp_size;
            gainObj.getcount();
            entropy = gainObj.currNodeEntophy();
            classLabel = gainObj.majorityLabel();
            currentsplit.classLabel = classLabel;

            if (entropy != 0.0 && currentsplit.attr_index.size() != total_attributes)
            {
                // Impure node with unused attributes: pick the attribute with
                // the best gain ratio and create one child per attribute value.
                System.out.println("");
                System.out.println("Entropy NOTT zero SPLIT INDEX:: " + entropy);
                best_gainratio = 0;
                for (int j = 0; j < total_attributes; j++) // gain of each attribute
                {
                    // Skip attributes already used on the path to this node.
                    if (!currentsplit.attr_index.contains(j))
                    {
                        gainratio = gainObj.gainratio(j, entropy);
                        if (gainratio >= best_gainratio)
                        {
                            split_index = j;
                            best_gainratio = gainratio;
                        }
                    }
                }
                // Distinct values of the chosen attribute = number of children.
                String attr_values_split = gainObj.getvalues(split_index);
                StringTokenizer attrs = new StringTokenizer(attr_values_split);
                int number_splits = attrs.countTokens();
                String red = "";
                System.out.println(" INDEX :: " + split_index);
                System.out.println(" SPLITTING VALUES " + attr_values_split);
                for (int splitnumber = 1; splitnumber <= number_splits; splitnumber++)
                {
                    temp_size = currentsplit.attr_index.size();
                    newnode = new MySplit();
                    // Clone the parent's attribute/value path ...
                    for (int y = 0; y < temp_size; y++)
                    {
                        newnode.attr_index.add(currentsplit.attr_index.get(y));
                        newnode.attr_value.add(currentsplit.attr_value.get(y));
                    }
                    // ... and extend it with the new attribute/value pair.
                    red = attrs.nextToken();
                    newnode.attr_index.add(split_index);
                    newnode.attr_value.add(red);
                    splitted.add(newnode);
                }
            }
            else
            {
                // Leaf node: emit the path plus majority class label as a rule.
                System.out.println("");
                String rule = "";
                temp_size = currentsplit.attr_index.size();
                for (int val = 0; val < temp_size; val++)
                {
                    rule = rule + " " + currentsplit.attr_index.get(val) + " " + currentsplit.attr_value.get(val);
                }
                rule = rule + " " + currentsplit.classLabel;
                writeRuleToFile(rule);
                if (entropy != 0.0)
                    System.out.println("Enter rule in file:: " + rule);
                else
                    System.out.println("Enter rule in file Entropy zero :: " + rule);
            }
            split_size = splitted.size();
            System.out.println("TOTAL NODES:: " + split_size);
            current_index++;
        }
        System.out.println("COMPLETE");
        System.exit(res);
    }

    /**
     * Appends one rule line to the rule file.  The writer is closed via
     * try-with-resources and failures are reported instead of being
     * silently swallowed (the original had an empty catch block).
     */
    public static void writeRuleToFile(String text)
    {
        try (BufferedWriter bw = new BufferedWriter(
                new FileWriter(new File("/home/hduser/C45/rule.txt/"), true))) {
            bw.write(text);
            bw.newLine();
        } catch (Exception e) {
            // Surface the failure instead of silently losing the rule.
            e.printStackTrace();
        }
    }

    /**
     * Configures and runs one counting job for the current node.
     * Output keys are "attrIndex attrValue classLabel" strings,
     * values are occurrence counts.
     */
    public int run(String[] args) throws Exception
    {
        JobConf conf = new JobConf(getConf(), DecisionTreec45.class);
        conf.setJobName("c4.5");
        conf.setOutputKeyClass(Text.class);
        conf.setOutputValueClass(IntWritable.class);
        conf.setMapperClass(MyMapper.class);
        conf.setReducerClass(MyReducer.class);
        // TODO(review): input/output paths are hard-coded; they should come
        // from args so the job is reusable.
        FileInputFormat.setInputPaths(conf, "/home/hduser/Id3_hds/playtennis.txt");
        FileOutputFormat.setOutputPath(conf, new Path("/home/hduser/Id3_hds/1/output" + current_index));
        JobClient.runJob(conf);
        return 0;
    }
}
此类用于计算节点的熵和各属性的增益比。
public class MyGainRatio
{
int linenumber=0;
static String count[][]=new String[10000][4];
int currnode[]=new int[100];
String majorityLabel=null;
public String majorityLabel()
{
return majorityLabel;
}
//Calculation of entrophy
public double currNodeEntophy()
{
int currentindex=0;
double entropy=0;
currentindex=Integer.parseInt(count[0][0]);
int i=0;
int covered[]=new int[1000];
String classLabel=count[0][2];
int j=0;
int ind=-1;
int maxStrength=0;
System.out.println("Values in node rep to classwise");
while(currentindex==Integer.parseInt(count[j][0]))
{
if(covered[j]==0)
{
classLabel=count[j][2];
ind++;
i=j;
while(currentindex==Integer.parseInt(count[i][0]))
{
if(covered[i]==0)
{
if(classLabel.contentEquals(count[i][2]))
{
currnode[ind] = currnode[ind]+Integer.parseInt(count[i][3]);
covered[i]=1;
}
}
i++;
if(i==linenumber)
break;
}
if(currnode[ind]>maxStrength)
{
maxStrength=currnode[ind];
majorityLabel=classLabel;
}
System.out.print(" "+classLabel+" "+currnode[ind]);
}
else
{
j++;
}
if(j==linenumber)
break;
}
entropy=entropy(currnode);
return entropy;
}
public double entropy(int c[])
{
double entropy=0;
int i=0;
int sum=0;
double frac;
while(c[i]!=0)
{
sum=sum+c[i];
i++;
}
i=0;
while(c[i]!=0)
{
frac=(double)c[i]/sum;
entropy=entropy-frac*(Math.log(frac)/Math.log(2));
i++;
}
return entropy;
}
public void getcount()
{
DecisionTreec45 id=new DecisionTreec45();
FileInputStream fstream;
try {
fstream = new FileInputStream("/home/hduser/C45/output/intermediate" + id.current_index + ".txt");
DataInputStream in = new DataInputStream(fstream);
BufferedReader br = new BufferedReader(new InputStreamReader(in));
String line;
//Read File Line By Line
StringTokenizer itr;
// System.out.println("READING FROM intermediate "+id.current_index);
while ((line = br.readLine()) != null) {
itr= new StringTokenizer(line);
count[linenumber][0]=itr.nextToken();
count[linenumber][1]=itr.nextToken();
count[linenumber][2]=itr.nextToken();
count[linenumber][3]=itr.nextToken();
int i=linenumber;
linenumber++;
}
count[linenumber][0]=null;
count[linenumber][1]=null;
count[linenumber][2]=null;
count[linenumber][3]=null;
in.close();
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
//Close the input stream
}
}
public double gainratio(int index,double enp)
{
//100 is considered as max ClassLabels
int c[][]=new int[1000][100];
int sum[]=new int[1000]; //
String currentatrrval="@3#441get";
double gainratio=0;
int j=0;
int m=-1; //index for split number
int lines=linenumber;
int totalsum=0;
for(int i=0;i<lines;i++)
{
if(Integer.parseInt(count[i][0])==index)
{
if(count[i][1].contentEquals(currentatrrval))
{
j++;
c[m][j]=Integer.parseInt(count[i][3]);
sum[m]=sum[m]+c[m][j];
}
else
{
j=0;
m++;
currentatrrval=count[i][1];
c[m][j]=Integer.parseInt(count[i][3]); //(different class) data sets count per m index split
sum[m]=c[m][j];
}
}
}
int p=0;
while(sum[p]!=0)
{
totalsum=totalsum+sum[p]; //calculating total instance in node
p++;
}
double wtenp=0;
double splitenp=0;
double part=0;
for(int splitnum=0;splitnum<=m;splitnum++)
{
part=(double)sum[splitnum]/totalsum;
wtenp=wtenp+part*entropy(c[splitnum]);
}
splitenp=entropy(sum);
gainratio=(enp-wtenp)/(splitenp);
return gainratio;
}
public String getvalues(int n)
{ int flag=0;
String values="";
String temp="%%%%%!!@";
for(int z=0;z<1000;z++)
{
if(count[z][0]!=null)
{
if(n==Integer.parseInt(count[z][0]))
{
flag=1;
if(count[z][1].contentEquals(temp))
{
// System.out.println("Equals COUNT Index z "+z+" "+count[z][1]+ "temp "+temp);
}
else
{
values=values+" "+count[z][1];
temp=count[z][1];
}
}
else if(flag==1)
break;
}
else
break;
}
return values;
}
}
此映射器类检查输入实例是否属于当前节点。对于所有尚未用于划分的属性,它输出该属性的索引、取值以及实例的类标签。
/**
 * Mapper for one iteration of the C4.5 tree construction.
 *
 * For every training instance that matches the attribute/value path of the
 * node currently being expanded (DecisionTreec45.currentsplit), it emits
 * "attrIndex attrValue classLabel" -> 1 for each attribute not yet used on
 * that path.  When every attribute has been used, it emits a sentinel record
 * "no_Attr null classLabel" so the reducer still sees the node.
 */
public class MyMapper extends MapReduceBase
        implements Mapper<LongWritable, Text, Text, IntWritable> {

    private final static IntWritable one = new IntWritable(1);
    private Text attValue = new Text();

    // Number of attributes per instance; derived from the input line
    // (token count minus the trailing class label).
    public static int no_Attr;

    public void map(LongWritable key, Text value,
            OutputCollector<Text, IntWritable> output, Reporter reporter)
            throws IOException {
        // currentsplit is static; read it directly instead of instantiating
        // the driver class as the original did.
        MySplit split = DecisionTreec45.currentsplit;

        StringTokenizer itr = new StringTokenizer(value.toString());
        no_Attr = itr.countTokens() - 1;
        String attr[] = new String[no_Attr];
        for (int i = 0; i < no_Attr; i++) {
            attr[i] = itr.nextToken(); // attribute values of this instance
        }
        String classLabel = itr.nextToken(); // last token is the class label

        // The instance belongs to the current node only if it agrees with
        // every attribute/value pair on the node's path.
        int size_split = split.attr_index.size();
        boolean match = true;
        for (int count = 0; count < size_split; count++) {
            int index = (Integer) split.attr_index.get(count);
            String attr_value = (String) split.attr_value.get(count);
            if (!attr[index].equals(attr_value)) {
                match = false;
                break;
            }
        }
        if (!match) {
            return;
        }

        // Emit one record per attribute that has not been split on yet.
        for (int l = 0; l < no_Attr; l++) {
            if (!split.attr_index.contains(l)) {
                attValue.set(l + " " + attr[l] + " " + classLabel);
                output.collect(attValue, one);
            }
        }
        // All attributes used: emit a sentinel so the node is still counted.
        if (size_split == no_Attr) {
            attValue.set(no_Attr + " " + "null" + " " + classLabel);
            output.collect(attValue, one);
        }
    }
}
此类统计每种组合(属性索引、属性值和类标签)的出现次数,并输出计数。
/**
 * Sums the occurrence counts for each (attribute index, attribute value,
 * class label) key emitted by MyMapper, writes the total to the job output
 * and appends it to the intermediate file of the current tree node.
 */
public class MyReducer extends MapReduceBase implements
        Reducer<Text, IntWritable, Text, IntWritable> {

    public void reduce(Text key, Iterator<IntWritable> values,
            OutputCollector<Text, IntWritable> output, Reporter reporter)
            throws IOException {
        int sum = 0;
        while (values.hasNext()) {
            sum += values.next().get();
        }
        output.collect(key, new IntWritable(sum));
        // Persist the aggregated count so the driver can compute gain ratios.
        writeToFile(key + " " + sum);
    }

    /**
     * Appends one line to the intermediate file of the node currently being
     * expanded (identified by DecisionTreec45.current_index).  The writer is
     * closed via try-with-resources and failures are reported instead of
     * being silently swallowed (the original had an empty catch block).
     */
    public static void writeToFile(String text) {
        // current_index is static; no DecisionTreec45 instance is needed.
        File file = new File("/home/hduser/C45/output/intermediate"
                + DecisionTreec45.current_index + ".txt");
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(file, true))) {
            bw.write(text);
            bw.newLine();
        } catch (Exception e) {
            // Do not fail the reduce task over a bookkeeping write, but
            // never lose the error silently either.
            e.printStackTrace();
        }
    }
}
此类拆分属性
public class MySplit implements Cloneable
{
public List attr_index;
public List attr_value;
double entophy;
String classLabel;
MySplit()
{
this.attr_index= new ArrayList<Integer>();
this.attr_value = new ArrayList<String>();
}
MySplit(List attr_index,List attr_value)
{
this.attr_index=attr_index;
this.attr_value=attr_value;
}
void add(MySplit obj)
{
this.add(obj);
}
}
输入训练数据集如下所示:
sunny hot high weak no
sunny hot high strong no
overcast hot high weak yes
rain mild high weak yes
rain cool normal weak yes
rain cool normal strong no
overcast cool normal strong yes
sunny mild high weak no
sunny cool normal weak yes
rain mild normal weak yes
sunny mild normal strong yes
overcast mild high strong yes
overcast hot normal weak yes
rain mild high strong no