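How can I calculate the true positive rate (TPR) and false positive rate (FPR) at different thresholds to generate a ROC curve for a classification model?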

Time: 2018-10-26 01:35:43

Tags: java api machine-learning weka roc

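I built a machine learning model that classifies documents with NaiveBayesMultinomial, and I am using the Java Weka API to train and test it. To evaluate the model's performance, I want to generate a ROC curve, but I do not understand how to calculate TPR and FPR at different thresholds. I have attached my source code and a sample dataset. I would appreciate any help with calculating TPR and FPR at different thresholds so that I can plot the ROC curve. Thanks in advance.

My Java code: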

    package smote;

    import java.io.File;
    import java.util.Random;

    import weka.classifiers.Classifier;
    import weka.classifiers.bayes.NaiveBayesMultinomial;
    import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;
    import weka.filters.Filter;
    import weka.filters.unsupervised.attribute.StringToWordVector;

    public class calRoc {
        public static void main(String[] args) throws Exception {
            String fileRootPath = "...../DocsFIle.arff";
            Instances rawData = DataSource.read(fileRootPath);

            // Convert the string attribute into a bag-of-words representation.
            StringToWordVector filter = new StringToWordVector(10000);
            String[] options = {
                "-W", "10000", "-L", "-M", "2",
                "-stemmer", "weka.core.stemmers.IteratedLovinsStemmer",
                "-stopwords-handler", "weka.core.stopwords.Rainbow",
                "-tokenizer", "weka.core.tokenizers.AlphabeticTokenizer"
            };
            filter.setOptions(options);
            filter.setIDFTransform(true);
            filter.setStopwords(new File("/Research/DoctoralReseacher/IEICE/Dataset/stopwords.txt"));
            // setInputFormat should be the last call before the filter is applied.
            filter.setInputFormat(rawData);
            Instances data = Filter.useFilter(rawData, filter);
            data.setClassIndex(0);

            int numRuns = 10;
            double[] recall = new double[numRuns];
            double[] precision = new double[numRuns];
            double[] fmeasure = new double[numRuns];
            double tp, fp, fn, tn;
            String classifierName[] = { "NBM" };
            double totalPrecision, totalRecall, totalFmeasure;
            totalPrecision = totalRecall = totalFmeasure = 0;
            double avgPrecision, avgRecall, avgFmeasure;
            avgPrecision = avgRecall = avgFmeasure = 0;

            for (int run = 0; run < numRuns; run++) {
                Classifier classifier = new NaiveBayesMultinomial();
                int folds = 10;
                Random random = new Random(1);
                data.randomize(random);
                data.stratify(folds);
                tp = fp = fn = tn = 0;

                // 10-fold cross-validation; the counts are pooled over all folds.
                for (int i = 0; i < folds; i++) {
                    Instances trains = data.trainCV(folds, i, random);
                    Instances tests = data.testCV(folds, i);
                    classifier.buildClassifier(trains);
                    for (int j = 0; j < tests.numInstances(); j++) {
                        Instance instance = tests.instance(j);
                        double classValue = instance.classValue();
                        double result = classifier.classifyInstance(instance);
                        if (result == 0.0 && classValue == 0.0) {
                            tp++;
                        } else if (result == 0.0 && classValue == 1.0) {
                            fp++;
                        } else if (result == 1.0 && classValue == 0.0) {
                            fn++;
                        } else if (result == 1.0 && classValue == 1.0) {
                            tn++;
                        }
                    }
                }

                // The counters above treat class index 0 as positive, so the
                // tn-based formulas below give precision and recall for class index 1.
                if (tn + fn > 0)
                    precision[run] = tn / (tn + fn);
                if (tn + fp > 0)
                    recall[run] = tn / (tn + fp);
                if (precision[run] + recall[run] > 0)
                    fmeasure[run] = 2 * precision[run] * recall[run] / (precision[run] + recall[run]);
                System.out.println("The " + (run + 1) + "-th run");
                System.out.println("Precision: " + precision[run]);
                System.out.println("Recall: " + recall[run]);
                System.out.println("Fmeasure: " + fmeasure[run]);
                totalPrecision += precision[run];
                totalRecall += recall[run];
                totalFmeasure += fmeasure[run];
            }

            avgPrecision = totalPrecision / numRuns;
            avgRecall = totalRecall / numRuns;
            avgFmeasure = totalFmeasure / numRuns;
            System.out.println("avgPrecision: " + avgPrecision);
            System.out.println("avgRecall: " + avgRecall);
            System.out.println("avgFmeasure: " + avgFmeasure);
        }
    }

Sample dataset with a few instances:

    @relation 'CamelBug'
    @attribute Feature string
    @attribute class-att {0,1}

    @data
    'XQuery creates an empty out message that makes it impossible to chain more processors behind it ',1
    'org apache camel Message hasAttachments is buggy ',0
    'unmarshal new JaxbDataFormat com foo bar returning JAXBElement ',0
    'Can t get the soap header when the camel cxf endpoint working in the PAYLOAD data fromat ',0
    'camel jetty Exchange failures should not be returned as ',1
    'Delayer not working as expected ',1
    'ParallelProcessing and executor flags are ignored in Multicast processor ',1
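
For reference, the threshold sweep the question asks about can be sketched in plain Java as below. This is only a minimal illustration under stated assumptions, not Weka's own method: the class `RocSketch`, the method `rocPoints`, and the toy scores are hypothetical. It assumes each test instance already has a positive-class score (in Weka this could plausibly be taken from `classifier.distributionForInstance(instance)`) and that an instance is predicted positive when its score is at least the threshold.

```java
/** Hypothetical helper (not part of Weka): sweeps thresholds over positive-class scores. */
final class RocSketch {

    /**
     * Returns one {FPR, TPR} pair per threshold.
     * Assumption: an instance is predicted positive when score >= threshold.
     */
    static double[][] rocPoints(double[] scores, boolean[] isPositive, double[] thresholds) {
        double[][] points = new double[thresholds.length][2];
        for (int t = 0; t < thresholds.length; t++) {
            int tp = 0, fp = 0, fn = 0, tn = 0;
            for (int i = 0; i < scores.length; i++) {
                boolean predictedPositive = scores[i] >= thresholds[t];
                if (predictedPositive && isPositive[i]) {
                    tp++;
                } else if (predictedPositive) {
                    fp++;
                } else if (isPositive[i]) {
                    fn++;
                } else {
                    tn++;
                }
            }
            // FPR = FP / (FP + TN); TPR = TP / (TP + FN). Guard against empty classes.
            points[t][0] = (fp + tn) > 0 ? (double) fp / (fp + tn) : 0.0;
            points[t][1] = (tp + fn) > 0 ? (double) tp / (tp + fn) : 0.0;
        }
        return points;
    }

    public static void main(String[] args) {
        // Toy scores and labels, made up purely for illustration.
        double[] scores = {0.9, 0.8, 0.6, 0.4, 0.3, 0.1};
        boolean[] isPositive = {true, true, false, true, false, false};
        double[] thresholds = {0.0, 0.5, 1.0};
        for (double[] p : rocPoints(scores, isPositive, thresholds)) {
            System.out.println("FPR=" + p[0] + " TPR=" + p[1]);
        }
    }
}
```

Plotting the (FPR, TPR) pairs for a fine grid of thresholds gives the ROC curve. Weka also ships `weka.classifiers.evaluation.ThresholdCurve`, which may make a manual sweep like this unnecessary.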

0 answers:

No answers.