使用Weka来预测赛马的结果

时间:2018-06-10 11:49:26

标签: java machine-learning weka

我正在试验Weka,我正试图用它来预测赛马的结果。我有很多数据要使用,我现有的属性是:

  1. horse_id(数字)
  2. driver_id(数字)
  3. trackCondition_id(数字1-5)
  4. raceDistance(数字)
  5. track_id(数字)
  6. kmTime(数字)
  7. 地方(数字从0-10)
  8. 我使用数据进行训练和评估,其中所有属性都已知,然后我提供了一个数据集,其中place-attribute未知。但我无法弄清楚如何使用weka来获取place属性的预测值。下面的代码来自我的小型测试应用程序,非常感谢帮助让我朝着正确的方向前进。

    使用classifier.classifyInstance(...)时得到的值总是0.0

    我从运行这个程序得到的输出是:

        Relation Name:  QueryResult
    Num Instances:  2475
    Num Attributes: 7
    
         Name                      Type  Nom  Int Real     Missing      Unique  Dist
    1 hbnHorse_id                Num   0% 100%   0%     0 /  0%  1514 / 61%  1873 
    2 hbnDriver_id               Num   0% 100%   0%     0 /  0%   274 / 11%   645 
    3 hbnTrackCondition_id       Num   0%  97%   0%    84 /  3%     0 /  0%     3 
    4 raceDistance               Nom 100%   0%   0%     0 /  0%     7 /  0%    26 
    5 hbnTrack_id                Num   0%  98%   0%    53 /  2%     2 /  0%    26 
    6 kmTime                     Num   0% 100%   0%     0 /  0%    74 /  3%   246 
    7 place                      Num   0% 100%   0%     0 /  0%     1 /  0%    10 
    
    
    Results
    ======
    
    Correctly Classified Instances         909               36.7421 %
    Incorrectly Classified Instances      1565               63.2579 %
    Kappa statistic                          0.1341
    Mean absolute error                      0.1313
    Root mean squared error                  0.3162
    Relative absolute error                 84.6014 %
    Root relative squared error            113.7708 %
    Total Number of Instances             2474     
    Ignored Class Unknown Instances                  1     
    
    Relation Name:  QueryResult-weka.filters.unsupervised.attribute.NumericToNominal-R7-7
    Num Instances:  1643
    Num Attributes: 7
    
         Name                      Type  Nom  Int Real     Missing      Unique  Dist
    1 hbnHorse_id                Num   0% 100%   0%     0 /  0%  1593 / 97%  1618 
    2 hbnDriver_id               Num   0% 100%   0%     0 /  0%   238 / 14%   508 
    3 hbnTrackCondition_id       Num   0% 100%   0%     0 /  0%     0 /  0%     2 
    4 raceDistance               Nom 100%   0%   0%     0 /  0%     0 /  0%     5 
    5 startMethod                Nom 100%   0%   0%     0 /  0%     0 /  0%     2 
    6 kmTime                     Num   0% 100%   0%     0 /  0%     0 /  0%     1 
    7 place                      Nom 100%   0%   0%     1 /  0%     0 /  0%     1 
    

    源代码:

       package org.peanuts.weka.core;
    
    import weka.classifiers.Classifier;
    import weka.classifiers.bayes.NaiveBayes;
    import weka.classifiers.evaluation.Evaluation;
    import weka.classifiers.trees.J48;
    import weka.core.Instances;
    import weka.experiment.InstanceQuery;
    import weka.filters.Filter;
    import weka.filters.unsupervised.attribute.NumericToNominal;
    
    import java.io.File;
    
    public class WekaTest {
    
        public static void main(String args[]) {
            try {
                InstanceQuery query = new InstanceQuery();
                query.setCustomPropsFile(new File("C:\\codez\\WekaPlayground\\src\\main\\resources\\props\\DatabaseUtils.props"));
                query.setUsername("wekaplayground");
                query.setPassword("password");
    //            query.setQuery("select s.hbnHorse_id, s.hbnDriver_id, r.hbnTrackCondition_id, r.raceDistance, r.startMethod, r.hbnTrack_id, s.kmTime, s.place\n" +
                query.setQuery("select s.hbnHorse_id, s.hbnDriver_id, r.hbnTrackCondition_id, r.raceDistance, r.hbnTrack_id, s.kmTime, s.place\n" +
                        "from hbnstart s, hbnrace r\n" +
                        "where s.hbnRace_id = r.id\n" +
                        "and r.hasResults = 1 " +
    //                    "and s.place > 0 " +
                        "and s.id < 22000");
                // You can declare that your train set is sparse
                // query.setSparseData(true);
    
                Instances train = query.retrieveInstances();
                if (train.classIndex() == -1)
                    train.setClassIndex(train.numAttributes() - 1);
    
                query.setQuery("select s.hbnHorse_id, s.hbnDriver_id, r.hbnTrackCondition_id, r.raceDistance, r.hbnTrack_id, s.kmTime, s.place\n" +
                        "from hbnstart s, hbnrace r\n" +
                        "where s.hbnRace_id = r.id\n" +
                        "and r.hasResults = 1 " +
    //                    "and s.place > 0 " +
                        "and s.id > 22000");
                Instances test = query.retrieveInstances();
                if (test.classIndex() == -1)
                    test.setClassIndex(test.numAttributes() - 1);
                System.out.println(test.toSummaryString());
    
                Classifier classifier = createClassifier();
    
                classifier.buildClassifier(convertToNominal(train));
                // evaluate classifier and print some statistics
                Evaluation eval = new Evaluation(convertToNominal(train));
                eval.evaluateModel(classifier, convertToNominal(test));
                System.out.println(eval.toSummaryString("\nResults\n======\n", false));
    
                if (train.classIndex() == -1)
                    train.setClassIndex(train.numAttributes() - 1);
    
                query.setQuery("select s.hbnHorse_id, s.hbnDriver_id, r.hbnTrackCondition_id, r.raceDistance, r.startMethod, s.kmTime, s.place\n" +
                        "from hbnstart s, hbnrace r\n" +
                        "where s.hbnRace_id = r.id\n" +
                        "and r.hasResults = 0");
    
                Instances unlabeled = query.retrieveInstances();
                // set class attribute
                unlabeled.setClassIndex(unlabeled.numAttributes() - 1);
                unlabeled = convertToNominal(unlabeled);
    
                // create copy
                Instances labeled = new Instances(unlabeled);
    
                System.out.println(labeled.toSummaryString());
    
                // label instances
                for (int i = 0; i < unlabeled.numInstances(); i++) {
                    double clsLabel = classifier.classifyInstance(unlabeled.instance(i));
                    labeled.instance(i).setClassValue(clsLabel);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    
        public static Instances convertToNominal(Instances instances) throws Exception {
            instances.get(0).setClassMissing();
            NumericToNominal convert = new NumericToNominal();
            String[] numToNomOptions = new String[2];
            numToNomOptions[0] = "-R";
            numToNomOptions[1] = "7-7";  //range of variables to make numeric
    
            convert.setOptions(numToNomOptions);
            convert.setInputFormat(instances);
    
            return Filter.useFilter(instances, convert);
        }
    
        private static Classifier createClassifier() throws Exception {
            String[] options = new String[1];
            options[0] = "-U";            // unpruned classifier
            J48 classifier = new J48();         // new instance of classifier
            classifier.setOptions(options);
            return classifier;
        }
    }
    

0 个答案:

没有答案