id3决策树程序

时间:2013-08-23 12:21:58

标签: machine-learning id3 decision-tree

请你检查一下这个程序会在netbeans没有任何错误的情况下运行。这个程序的输出会是什么。尝试运行这个程序但是我得到的输出只读取我输入文件的数据.arff file.im得到如下输出:读取0数据 最后一行:@attribute handicapped-infants {String,n,y} 期待2个属性

import java.io.*;

import java.util.*;

public class ID3

{

int numAttributes;

String []attributeNames;

Vector []domains;

/* The class to represent a data point consisting of numAttributes values

of attributes */

class DataPoint {

public int []attributes;

public DataPoint(int numattributes) {

attributes = new int[numattributes];

}

};

/* The class to represent a node in the decomposition tree.

*/

class TreeNode {

public double entropy;

public Vector data;

public int decompositionAttribute;

public int decompositionValue;

public TreeNode []children;

public TreeNode parent;

public TreeNode() {

data = new Vector();

}

};

TreeNode root = new TreeNode();

public int getSymbolValue(int attribute, String symbol) {

int index = domains[attribute].indexOf(symbol);

if (index < 0) {

domains[attribute].addElement(symbol);

return domains[attribute].size() -1;

}

return index;

}

public int []getAllValues(Vector data, int attribute) {

Vector values = new Vector();

int num = data.size();

for (int i=0; i< num; i++) {

DataPoint point = (DataPoint)data.elementAt(i);

String symbol =

(String)domains[attribute].elementAt(point.attributes[attribute] );

int index = values.indexOf(symbol);

if (index < 0) {

values.addElement(symbol);

}

}

int []array = new int[values.size()];

for (int i=0; i< array.length; i++) {

String symbol = (String)values.elementAt(i);

array[i] = domains[attribute].indexOf(symbol);

}

values = null;

return array;

}

public Vector getSubset(Vector data, int attribute, int value) {

Vector subset = new Vector();

int num = data.size();

for (int i=0; i< num; i++) {

DataPoint point = (DataPoint)data.elementAt(i);

if (point.attributes[attribute] == value) subset.addElement(point);

}

return subset;

}

public double calculateEntropy(Vector data) {

int numdata = data.size();

if (numdata == 0) return 0;

int attribute = numAttributes-1;

int numvalues = domains[attribute].size();

double sum = 0;

for (int i=0; i< numvalues; i++) {

int count=0;

for (int j=0; j< numdata; j++) {

DataPoint point = (DataPoint)data.elementAt(j);

if (point.attributes[attribute] == i) count++;

}

double probability = 1.*count/numdata;

if (count > 0) sum += -probability*Math.log(probability);

}

return sum;

}

public boolean alreadyUsedToDecompose(TreeNode node, int attribute) {

if (node.children != null) {

if (node.decompositionAttribute == attribute )

return true;

}

if (node.parent == null) return false;

return alreadyUsedToDecompose(node.parent, attribute);

}

public void decomposeNode(TreeNode node) {

double bestEntropy=0;

boolean selected=false;

int selectedAttribute=0;

int numdata = node.data.size();

int numinputattributes = numAttributes-1;

node.entropy = calculateEntropy(node.data);

if (node.entropy == 0) return;

for (int i=0; i< numinputattributes; i++) {

int numvalues = domains[i].size();

if ( alreadyUsedToDecompose(node, i) ) continue;

double averageentropy = 0;

for (int j=0; j< numvalues; j++) {

Vector subset = getSubset(node.data, i, j);

if (subset.size() == 0) continue;

double subentropy = calculateEntropy(subset);

averageentropy += subentropy *

subset.size();

}

averageentropy = averageentropy / numdata; //

Taking the weighted average

if (selected == false) {

selected = true;

bestEntropy = averageentropy;

selectedAttribute = i;

} else {

if (averageentropy < bestEntropy) {

selected = true;

bestEntropy = averageentropy;

selectedAttribute = i;

}

}

}

if (selected == false) return;

int numvalues = domains[selectedAttribute].size();

node.decompositionAttribute = selectedAttribute;

node.children = new TreeNode [numvalues];

for (int j=0; j< numvalues; j++) {

node.children[j] = new TreeNode();

node.children[j].parent = node;

node.children[j].data = getSubset(node.data,

selectedAttribute, j);

node.children[j].decompositionValue = j;

}

for (int j=0; j< numvalues; j++) {

decomposeNode(node.children[j]);

}

node.data = null;

}

public int readData(String filename) throws Exception {

FileInputStream in = null;

try {

File inputFile = new File(filename);

in = new FileInputStream(inputFile);

} catch ( Exception e) {

System.err.println( "Unable to open data file: " + filename + "n" + e);

return 0;

}

BufferedReader bin = new BufferedReader(new InputStreamReader(in) );

String input;

while(true) {

input = bin.readLine();

if (input == null) {

System.err.println( "No data found in the data file: " + filename +

"n");

return 0;

}

if (input.startsWith("//")) continue;

if (input.equals("")) continue;

break;

}

StringTokenizer tokenizer = new StringTokenizer(input);

numAttributes = tokenizer.countTokens();

if (numAttributes <= 1) {

System.err.println( "Read line: " + input);

System.err.println( "Could not obtain the names of attributes in the

line");

System.err.println( "Expecting at least one input attribute and one

output attribute");

return 0;

}

domains = new Vector[numAttributes];

for (int i=0; i < numAttributes; i++) domains[i] = new Vector();

attributeNames = new String[numAttributes];

for (int i=0; i < numAttributes; i++) {

attributeNames[i] = tokenizer.nextToken();

}

while(true) {

input = bin.readLine();

if (input == null) break;

if (input.startsWith("//")) continue;

if (input.equals("")) continue;

tokenizer = new StringTokenizer(input);

int numtokens = tokenizer.countTokens();

if (numtokens != numAttributes) {

System.err.println( "Read " + root.data.size() + " data");

System.err.println( "Last line read: " + input);

System.err.println( "Expecting " + numAttributes + " attributes");

return 0;

}

DataPoint point = new DataPoint(numAttributes);

for (int i=0; i < numAttributes; i++) {

point.attributes[i] = getSymbolValue(i, tokenizer.nextToken()

);

}

root.data.addElement(point);

}

bin.close();

return 1;

}

public void printTree(TreeNode node, String tab) {

int outputattr = numAttributes-1;

if (node.children == null) {

int []values = getAllValues(node.data, outputattr );

if (values.length == 1) {

System.out.println(tab + "t" + attributeNames[outputattr] + " = "" +

domains[outputattr].elementAt(values[0]) + "";");

return;

}

System.out.print(tab + "t" + attributeNames[outputattr] + " = {");

for (int i=0; i < values.length; i++) {

System.out.print(""" + domains[outputattr].elementAt(values[i]) + ""

");

if ( i != values.length-1 ) System.out.print( " , " );

}

System.out.println( " };");

return;

}

int numvalues = node.children.length;

for (int i=0; i < numvalues; i++) {

System.out.println(tab + "if( " +

attributeNames[node.decompositionAttribute] + " == "" +

domains[node.decompositionAttribute].elementAt(i)

+ "") {" );

printTree(node.children[i], tab + "t");

if (i != numvalues-1) System.out.print(tab + "} else ");

else System.out.println(tab + "}");

}

}

public void createDecisionTree() {

decomposeNode(root);

printTree(root, "");

}

/* main function */

public static void main(String[] args) throws Exception {

ID3 me = new ID3();

int status = me.readData("c:\\in.txt");

if (status <= 0) return;

me.createDecisionTree();

}

}

1 个答案:

答案 0 :(得分:1)

输出将是一棵树。但是,您需要传递一个至少包含2个属性的文件。