RandomForest with Weka in Java

Date: 2015-11-25 12:58:39

Tags: java weka

I am working on a project and need an example of how to implement RandomForest in Java with Weka. I did it with IBk() and it worked, but if I use RandomForest in the same way it does not work. Does anyone have a simple example of how to use RandomForest and how to get the probability for each class? (With IBk I did it via classifier.distributionForInstance(instance), which returns the probability for each class.) How can I do the same for RandomForest? Do I need to get the probability from each tree and combine them?

//example

ConverterUtils.DataSource source = new ConverterUtils.DataSource("..../edit.arff");
Instances dataset = source.getDataSet();
dataset.setClassIndex(dataset.numAttributes() - 1); 
IBk classifier = new IBk(5);   //k-nearest-neighbours classifier with k = 5
classifier.buildClassifier(dataset);

Instance instance = new SparseInstance(2); 
instance.setValue(0, 65);   //example data 
instance.setValue(1, 120);   //example data 
double[] prediction = classifier.distributionForInstance(instance);

//now I get the probability for the first class   
System.out.println("Prediction for the first class is: "+prediction[0]);

1 Answer:

Answer 0 (Score: 3):

You can compute the infogain while building the model with RandomForest. It is slower and needs a lot of memory while building the model. I am not too sure about the documentation. You can add options or use the setter methods while building the model.

    // numFolds is the number of cross-validation folds (typically 10)
    int numFolds = 10;

    // br is your BufferedReader for the training ARFF data
    Instances trainData = new Instances(br);
    trainData.setClassIndex(trainData.numAttributes() - 1);

    RandomForest rf = new RandomForest();

    // You can set options here; setOptions() resets anything it is not given
    // back to its default, so call it before the individual setters
    String[] options = new String[]{"-S", "1"};   // "-S 1" sets the random seed
    rf.setOptions(options);

    rf.setNumTrees(50);

    rf.buildClassifier(trainData);


    // Rank the attributes by information gain with a supervised attribute-selection filter
    weka.filters.supervised.attribute.AttributeSelection as = new weka.filters.supervised.attribute.AttributeSelection();
    Ranker ranker = new Ranker();



    InfoGainAttributeEval infoGainAttrEval = new InfoGainAttributeEval();
    as.setEvaluator(infoGainAttrEval);
    as.setSearch(ranker);
    as.setInputFormat(trainData);
    // Ranker keeps all attributes but reorders them by score; keep trainData in its
    // original order so that evaluateAttribute(i) below lines up with attribute(i)
    Instances rankedData = Filter.useFilter(trainData, as);

    // Estimate performance with numFolds-fold cross-validation
    Evaluation evaluation = new Evaluation(trainData);
    evaluation.crossValidateModel(rf, trainData, numFolds, new Random(1));


    // Using HashMap to store the infogain values of the attributes 
    int count = 0;
    Map<String, Double> infogainscores = new HashMap<String, Double>();

    for (int i = 0; i < trainData.numAttributes(); i++) {
        String t_attr = trainData.attribute(i).name();
        //System.out.println(i + trainData.attribute(i).name());
        double infogain = infoGainAttrEval.evaluateAttribute(i);
        if (infogain != 0) {
            //System.out.println(t_attr + "= " + infogain);
            infogainscores.put(t_attr, infogain);
            count = count + 1;
        }
    }

    //iterating over the hashmap

    // Print each attribute and its infogain score (in arbitrary HashMap order)
    for (Map.Entry<String, Double> pair : infogainscores.entrySet()) {
        System.out.println(pair.getKey() + "  =  " + pair.getValue());
    }