Classify all instances and get the highest probability of the positive class label in Weka

Time: 2015-03-05 03:51:21

Tags: java classification weka

I have created Weka ARFF instances. I have to force each dataset to contain exactly one positive label. This can be done as follows:

For all instances:
   get the probability of predicted class label 1,
   label the instance with the highest probability as 1 and the others as 0
   if no instance is predicted as class label 1:
      get the lowest probability of predicted class 0
      label that instance as predicted class 1
      the other instances' predictions will be 0
   for the instance that is labeled 1, check the actual value of its class; if it is the same, then score = score + 1.
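
For a two-class problem the two probabilities of an instance sum to 1, so the instance with the lowest probability of class 0 is also the one with the highest probability of class 1. Under that assumption both branches of the rule above reduce to "pick the instance with the largest class-1 probability". Here is a minimal sketch in plain Java with no Weka dependency. The class and method names are hypothetical, and probClass1 is assumed to hold the classifier's probability of class 1 for each instance:

```java
class OnePositivePerDataset {
    // Index of the largest value; ties resolve to the first occurrence.
    static int argMax(double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("no instances");
        }
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Exactly one instance is labeled 1: the one with the highest P(class 1).
    // All other instances keep the default label 0.
    static int[] forceSinglePositive(double[] probClass1) {
        int[] labels = new int[probClass1.length];
        labels[argMax(probClass1)] = 1;
        return labels;
    }

    // Returns 1 if the instance chosen as positive is actually positive, else 0.
    static int score(double[] probClass1, int[] actual) {
        return actual[argMax(probClass1)] == 1 ? 1 : 0;
    }
}
```

With probabilities {0.2, 0.4, 0.3} no instance would normally be predicted as class 1, yet the second instance is still chosen as the single positive, which is the fallback case in the pseudocode.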

I can do the classification in Weka:

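import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

// Load the ARFF file and use the last attribute as the class if none is set.
DataSource source = new DataSource(outputFolderPath + "/" + fileName + ".arff");
Instances data = source.getDataSet();
if (data.classIndex() == -1)
    data.setClassIndex(data.numAttributes() - 1);

// Load the previously trained model.
Classifier cls = (Classifier) weka.core.SerializationHelper.read(mainPath + "meta.model");

// classifyInstance returns the index of the predicted class value.
double prediction = cls.classifyInstance(data.instance(0));
String predicted_label = data.classAttribute().value((int) prediction);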

How can I classify all instances, and how can I use the probabilities to get the result I want? I found a partial solution this way:

for (int j = 0; j < data_test.numInstances(); j++) {
    double prediction = cls.classifyInstance(data_test.instance(j));
    double[] prob = cls.distributionForInstance(data_test.instance(j));
    // prob[0] is the probability of class 0 and prob[1] is the probability of class 1
}

Now the only remaining problem is the second part: how to find the highest probability for class 1.

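Update 2: I tried to store the actual class value, the prediction, and the probability of class 1 in an array and sort the rows by probability, but I get this compile error: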

The method sort(T[], Comparator<? super T>) in the type Arrays is not applicable for the arguments (double[][], new Comparator<Double[]>(){})

It comes from this code:

for (int j = 0; j < data_test.numInstances(); j++) {
    double prediction = cls.classifyInstance(data_test.instance(j));
    //System.out.println(data_test.instance(j));
    double[] prob = cls.distributionForInstance(data_test.instance(j));
    //System.out.println(prediction+"--->"+prob[0]+","+prob[1]);
    //System.out.println(data_test.classAttribute().value((int) data_test.instance(j).classValue()));

    arrayNumbers[j][0] = Double.parseDouble(data_test.classAttribute().value((int) data_test.instance(j).classValue()));
    arrayNumbers[j][1] = prediction;
    arrayNumbers[j][2] = prob[1];
}

//System.out.println(arrayNumbers);
Arrays.sort(arrayNumbers, new Comparator<Double[]>() {
    public int compare(Double[] s1, Double[] s2) {
        if (s1[0] > s2[0])
            return 1;
        else if (s1[0] < s2[0])
            return -1;
        else {
            return 0;
        }
    }
});

1 Answer:

Answer 0 (score: 0)

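Problem solved. arrayNumbers is a double[][], so its elements are double[] and the comparator has to be typed on double[] rather than Double[]. It also has to compare column 2, which holds the probability of class 1: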

Arrays.sort(arrayNumbers, new Comparator<double[]>() {
    @Override
    public int compare(double[] o1, double[] o2) {
        return Double.compare(o1[2], o2[2]);
    }
});

// After the ascending sort, the last row has the highest probability of class 1.
if (arrayNumbers[data_test.numInstances() - 1][0] == 1.0) {
    plus = plus + 1;
}
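
The same sort can be written more compactly, assuming Java 8 or later. This is a sketch rather than a definitive implementation: the class and method names are hypothetical, and each row is assumed to follow the arrayNumbers layout of {actual class, predicted class, probability of class 1}. Java does not box a double[] into a Double[], which is why the comparator is typed on double[] here as well:

```java
import java.util.Arrays;
import java.util.Comparator;

class SortRowsByProbability {
    // Returns the row with the highest value in column 2 (probability of class 1).
    static double[] rowWithHighestProbability(double[][] rows) {
        if (rows.length == 0) {
            throw new IllegalArgumentException("no rows");
        }
        // Shallow copy so the caller's row order is left untouched.
        double[][] sorted = rows.clone();
        // The explicit (double[] row) parameter type lets the compiler infer Comparator<double[]>.
        Arrays.sort(sorted, Comparator.comparingDouble((double[] row) -> row[2]));
        // Ascending order: the last row has the largest probability.
        return sorted[sorted.length - 1];
    }
}
```

The counter is then updated only if column 0 of the returned row equals 1.0. Sorting costs more than a single pass that tracks the maximum, but it also gives the full ranking if that is needed later.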