Java, Weka: LibSVM does not predict correctly

Asked: 2017-05-13 01:12:27

Tags: java machine-learning regression weka libsvm

I am using LibSVM with Weka in my Java code. I want to do regression. Here is my code:

public static void predict() {

    try {
        DataSource sourcePref1 = new DataSource("train_pref2new.arff");
        Instances trainData = sourcePref1.getDataSet();

        DataSource sourcePref2 = new DataSource("testDatanew.arff");
        Instances testData = sourcePref2.getDataSet();

        if (trainData.classIndex() == -1) {
            trainData.setClassIndex(trainData.numAttributes() - 2);
        }

        if (testData.classIndex() == -1) {
            testData.setClassIndex(testData.numAttributes() - 2);
        }

        LibSVM svm1 = new LibSVM();

        String options = ("-S 3 -K 2 -D 3 -G 1000.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.001 -P 0.1");
        String[] optionsArray = options.split(" ");
        svm1.setOptions(optionsArray);

        svm1.buildClassifier(trainData);

        for (int i = 0; i < testData.numInstances(); i++) {

            double pref1 = svm1.classifyInstance(testData.instance(i));                
            System.out.println("predicted value : " + pref1);

        }

    } catch (Exception ex) {
        Logger.getLogger(Test.class.getName()).log(Level.SEVERE, null, ex);
    }
}

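However, the predicted values I get from this code differ from the ones I get from the Weka GUI.

Example: below is the single test instance that I gave to both the Java code and the Weka GUI.

The Java code predicts 1.9064516129032265, whereas the Weka GUI predicts 10.043. I use the same training dataset and the same parameters for both the Java code and the Weka GUI.

I hope my question is clear. Can anyone tell me what is wrong with my code?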

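1 Answer:

Answer 0 (score: 2)

You are using the wrong algorithm for SVM regression. LibSVM is for classification. What you want is SMOreg, an SVM implementation specifically for regression.

Below is a complete example showing how to use SMOreg from both the Weka Explorer GUI and the Java API. For data, I use the cpu.arff file that ships with the Weka distribution. Note that I use this one file for both training and testing; ideally you would have separate datasets.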
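Because the example reuses one file for training and testing, here is a minimal sketch of how a holdout split could be produced instead. It is plain Java and not part of the original answer; the class name HoldoutSplitSketch, the 80/20 fraction, and the seed value are assumptions made purely for illustration. The integer rows stand in for the instances of a loaded dataset.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper, not part of Weka: shuffles rows and cuts them into two parts.
final class HoldoutSplitSketch {

    /** Returns two lists: index 0 holds the training rows, index 1 the test rows. */
    static <T> List<List<T>> split(List<T> rows, double trainFraction, long seed) {
        if (trainFraction <= 0.0 || trainFraction >= 1.0) {
            throw new IllegalArgumentException("trainFraction must be strictly between 0 and 1");
        }
        // Copy first so the caller's list is left untouched, then shuffle reproducibly.
        List<T> shuffled = new ArrayList<>(rows);
        Collections.shuffle(shuffled, new Random(seed));

        int trainSize = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, trainSize)));
        parts.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
        return parts;
    }

    public static void main(String[] args) {
        // 209 placeholder rows, the same count as the instances in cpu.arff.
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 209; i++) {
            rows.add(i);
        }
        List<List<Integer>> parts = split(rows, 0.8, 42L);
        System.out.printf("train=%d, test=%d%n", parts.get(0).size(), parts.get(1).size());
    }
}
```

Fixing the seed keeps the split reproducible, so that two runs can be compared on the same test rows.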

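Using the Weka Explorer GUI

  1. Open the Weka Explorer GUI, click the Preprocess tab, click Open File, and open the cpu.arff file that should be present in your Weka distribution. On my system the file is at weka-3-8-1/data/cpu.arff. The Explorer window should look like this:

     [Image: Weka Explorer - Choosing the file]

  2. Click the Classify tab. It really ought to be called "Predict", because you can do both classification and regression here. Under Classifier, click Choose, then select weka -> classifiers -> functions -> SMOreg, as shown below.

     [Image: Weka Explorer - Choosing the regression algorithm]

  3. Now build the regression model and evaluate it. Under Test Options, select Use training set so that the training set is also used for testing (as noted above, this is not the ideal approach). Press Start; the results should look like this:

     [Image: Weka Explorer - Results from testing]

Note the RMSE value (74.5996). We will revisit it in the Java implementation.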

Using the Java API

Below is a complete Java program that uses the Weka API to reproduce the results shown earlier in the Weka Explorer GUI.

        import weka.classifiers.functions.SMOreg;
        import weka.classifiers.Evaluation;
        import weka.core.Instance;
        import weka.core.Instances;
        import weka.core.converters.ConverterUtils.DataSource;
        
        public class Tester {
        
            /**
             * Builds a regression model using SMOreg, the SVM for regression, and 
             * evaluates it with the Evaluation framework.
             */
            public void buildAndEvaluate(String trainingArff, String testArff) throws Exception {
        
                System.out.printf("buildAndEvaluate() called.\n");
        
                // Load the training and test instances.
                Instances trainingInstances = DataSource.read(trainingArff);
                Instances testInstances = DataSource.read(testArff);
        
                // Set the true value to be the last field in each instance.
                trainingInstances.setClassIndex(trainingInstances.numAttributes()-1);
                testInstances.setClassIndex(testInstances.numAttributes()-1);
        
                // Build the SMOregression model.
                SMOreg smo = new SMOreg();
                smo.buildClassifier(trainingInstances);
        
                // Use Weka's evaluation framework.
                Evaluation eval = new Evaluation(trainingInstances);
                eval.evaluateModel(smo, testInstances);
        
                // Print the options that were used in the ML algorithm.
                String[] options = smo.getOptions();
                System.out.printf("Options used:\n");
                for (String option : options) {
                    System.out.printf("%s ", option);
                }
                System.out.printf("\n\n");
        
                // Print the algorithm details.
                System.out.printf("Algorithm:\n %s\n", smo.toString());
        
                // Print the evaluation results.
                System.out.printf("%s\n", eval.toSummaryString("\nResults\n=====\n", false));
            }
        
            /**
             * Builds a regression model using SMOreg, the SVM for regression, and 
             * tests each data instance individually to compute RMSE.
             */
            public void buildAndTestEachInstance(String trainingArff, String testArff) throws Exception {
        
                System.out.printf("buildAndTestEachInstance() called.\n");
        
                // Load the training and test instances.
                Instances trainingInstances = DataSource.read(trainingArff);
                Instances testInstances = DataSource.read(testArff);
        
                // Set the true value to be the last field in each instance.
                trainingInstances.setClassIndex(trainingInstances.numAttributes()-1);
                testInstances.setClassIndex(testInstances.numAttributes()-1);
        
                // Build the SMOregression model.
                SMOreg smo = new SMOreg();
                smo.buildClassifier(trainingInstances);
        
                int numTestInstances = testInstances.numInstances();
        
                // This variable accumulates the squared error from each test instance.
                double sumOfSquaredError = 0.0;
        
                // Loop over each test instance.
                for (int i = 0; i < numTestInstances; i++) {
        
                    Instance instance = testInstances.instance(i);
        
                    double trueValue = instance.value(testInstances.classIndex());
                    double predictedValue = smo.classifyInstance(instance);
        
                    // Uncomment the next line to see every prediction on the test instances.
                    //System.out.printf("true=%10.5f, predicted=%10.5f\n", trueValue, predictedValue);
        
                    double error = trueValue - predictedValue;
                    sumOfSquaredError += (error * error);
                }
        
                // Print the RMSE results.
                double rmse = Math.sqrt(sumOfSquaredError / numTestInstances);
                System.out.printf("RMSE = %10.5f\n", rmse);
            }
        
            public static void main(String argv[]) throws Exception {
        
                Tester classify = new Tester();
                classify.buildAndEvaluate("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
                classify.buildAndTestEachInstance("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
            }
        }
        

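I wrote two functions that train an SMOreg model and evaluate it by running predictions on the training data.

• buildAndEvaluate() evaluates the model with Weka's Evaluation framework, which runs a suite of tests and produces exactly the same results as the Explorer GUI. Notably, it reports an RMSE value.

• buildAndTestEachInstance() evaluates the model explicitly: it loops over each test instance, makes a prediction, computes the error, and then computes the overall RMSE. Note that this RMSE matches the one from buildAndEvaluate(), which in turn matches the one from the Explorer GUI.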
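The per-instance RMSE calculation described for buildAndTestEachInstance() can also be looked at in isolation, without Weka. The following is a minimal sketch under that assumption; the class name RmseSketch and the sample values are hypothetical and were chosen only to keep the arithmetic easy to follow.

```java
// Hypothetical standalone helper that mirrors the RMSE loop, without any Weka types.
final class RmseSketch {

    /** Root mean squared error between true and predicted values of equal length. */
    static double rmse(double[] trueValues, double[] predictedValues) {
        if (trueValues.length == 0 || trueValues.length != predictedValues.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        // Accumulate the squared error of every prediction.
        double sumOfSquaredError = 0.0;
        for (int i = 0; i < trueValues.length; i++) {
            double error = trueValues[i] - predictedValues[i];
            sumOfSquaredError += error * error;
        }
        // Average the squared errors, then take the square root.
        return Math.sqrt(sumOfSquaredError / trueValues.length);
    }

    public static void main(String[] args) {
        // Made-up values: every prediction is off by exactly 2.
        double[] trueValues = {1.0, 2.0, 3.0, 4.0};
        double[] predictedValues = {3.0, 0.0, 5.0, 2.0};
        System.out.printf("RMSE = %10.5f%n", rmse(trueValues, predictedValues));
    }
}
```

In the Weka program, the two arrays would correspond to the class value of each test instance and the value returned by classifyInstance() for it.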

Here are the results of compiling and running the program.

        prompt> javac -cp weka.jar Tester.java
        
        prompt> java -cp .:weka.jar Tester
        
        buildAndEvaluate() called.
        Options used:
        -C 1.0 -N 0 -I weka.classifiers.functions.supportVector.RegSMOImproved -T 0.001 -V -P 1.0E-12 -L 0.001 -W 1 -K weka.classifiers.functions.supportVector.PolyKernel -E 1.0 -C 250007 
        
        Algorithm:
         SMOreg
        
        weights (not support vectors):
         +       0.01   * (normalized) MYCT
         +       0.4321 * (normalized) MMIN
         +       0.1847 * (normalized) MMAX
         +       0.1175 * (normalized) CACH
         +       0.0973 * (normalized) CHMIN
         +       0.0235 * (normalized) CHMAX
         -       0.0168
        
        
        
        Number of kernel evaluations: 21945 (93.081% cached)
        
        Results
        =====
        
        Correlation coefficient                  0.9044
        Mean absolute error                     31.7392
        Root mean squared error                 74.5996
        Relative absolute error                 33.0908 %
        Root relative squared error             46.4953 %
        Total Number of Instances              209     
        
        buildAndTestEachInstance() called.
        RMSE =   74.59964