Libsvm java训练测试示例(也是实时)

时间:2013-12-08 19:34:47

标签: java machine-learning computer-vision libsvm

任何人都可以通过提供libsvm java示例进行培训和测试来帮助我。我是机器学习的新手,需要相同的帮助。早期提供的示例由@machine学习者有错误只给出一个类结果。我不想在之前的帖子中使用weka作为建议。

或者你可以纠正这段代码中的错误,它总是在结果中预测一个类。(我想执行多分类)。

本例由“机器学习者”

给出
import java.io.*;
import java.util.*;
import libsvm.*;

public class Test{
    public static void main(String[] args) throws Exception{

        // Preparing the SVM param
        svm_parameter param=new svm_parameter();
        param.svm_type=svm_parameter.C_SVC;
        param.kernel_type=svm_parameter.RBF;
        param.gamma=0.5;
        param.nu=0.5;
        param.cache_size=20000;
        param.C=1;
        param.eps=0.001;
        param.p=0.1;

        HashMap<Integer, HashMap<Integer, Double>> featuresTraining=new HashMap<Integer, HashMap<Integer, Double>>();
        HashMap<Integer, Integer> labelTraining=new HashMap<Integer, Integer>();
        HashMap<Integer, HashMap<Integer, Double>> featuresTesting=new HashMap<Integer, HashMap<Integer, Double>>();

        HashSet<Integer> features=new HashSet<Integer>();

        //Read in training data
        BufferedReader reader=null;
        try{
            reader=new BufferedReader(new FileReader("a1a.train"));
            String line=null;
            int lineNum=0;
            while((line=reader.readLine())!=null){
                featuresTraining.put(lineNum, new HashMap<Integer,Double>());
                String[] tokens=line.split("\\s+");
                int label=Integer.parseInt(tokens[0]);
                labelTraining.put(lineNum, label);
                for(int i=1;i<tokens.length;i++){
                    String[] fields=tokens[i].split(":");
                    int featureId=Integer.parseInt(fields[0]);
                    double featureValue=Double.parseDouble(fields[1]);
                    features.add(featureId);
                    featuresTraining.get(lineNum).put(featureId, featureValue);
                }
            lineNum++;
            }

            reader.close();
        }catch (Exception e){

        }

        //Read in test data
        try{
            reader=new BufferedReader(new FileReader("a1a.t"));
            String line=null;
            int lineNum=0;
            while((line=reader.readLine())!=null){

                featuresTesting.put(lineNum, new HashMap<Integer,Double>());
                String[] tokens=line.split("\\s+");
                for(int i=1; i<tokens.length;i++){
                    String[] fields=tokens[i].split(":");
                    int featureId=Integer.parseInt(fields[0]);
                    double featureValue=Double.parseDouble(fields[1]);
                    featuresTesting.get(lineNum).put(featureId, featureValue);
                }
            lineNum++;
            }
            reader.close();
        }catch (Exception e){

        }

        //Train the SVM model
        svm_problem prob=new svm_problem();
        int numTrainingInstances=featuresTraining.keySet().size();
        prob.l=numTrainingInstances;
        prob.y=new double[prob.l];
        prob.x=new svm_node[prob.l][];

        for(int i=0;i<numTrainingInstances;i++){
            HashMap<Integer,Double> tmp=featuresTraining.get(i);
            prob.x[i]=new svm_node[tmp.keySet().size()];
            int indx=0;
            for(Integer id:tmp.keySet()){
                svm_node node=new svm_node();
                node.index=id;
                node.value=tmp.get(id);
                prob.x[i][indx]=node;
                indx++;
            }

            prob.y[i]=labelTraining.get(i);
        }

        svm_model model=svm.svm_train(prob,param);

        for(Integer testInstance:featuresTesting.keySet()){
            HashMap<Integer, Double> tmp=new HashMap<Integer, Double>();
            int numFeatures=tmp.keySet().size();
            svm_node[] x=new svm_node[numFeatures];
            int featureIndx=0;
            for(Integer feature:tmp.keySet()){
                x[featureIndx]=new svm_node();
                x[featureIndx].index=feature;
                x[featureIndx].value=tmp.get(feature);
                featureIndx++;
            }

            double d=svm.svm_predict(model, x);

            System.out.println(testInstance+"\t"+d);
        }

    }
}

4 个答案:

答案 0 :(得分:4)

这是因为您的featuresTesting从未使用过,HashMap<Integer, Double> tmp=new HashMap<Integer, Double>();应为HashMap<Integer, Double> tmp=featuresTesting.get(testInstance);

答案 1 :(得分:2)

您可以使用javaML库对数据进行分类

这是一个带有javaML的示例代码:

   Classifier clas = new LibSVM();
        clas.buildClassifier(data);
        Dataset dataForClassification= FileHandler.loadDataset(new File(.),            0, ",");
        /* Counters for correct and wrong predictions. */
        int correct = 0, wrong = 0;
        /* Classify all instances and check with the correct class values */
        for (Instance inst : dataForClassification) {
            Object predictedClassValue = clas.classify(inst);
            Map<Object,Double> map = clas.classDistribution(inst);
            Object realClassValue = inst.classValue();
            if (predictedClassValue.equals(realClassValue))
                correct++;
            else
                wrong++;
        }

答案 2 :(得分:1)

您似乎无法理解自己在做什么,而只是从此处和那里复制代码。它可以帮助您理解基本的机器学习。例如,你应该从LIBSVM(你使用的库)的作者那里读到这个practical guide for SVM classification。你在这里得到的建议,你应该在网上进行入门机器学习课程可能更好。

让我也给你两个重要提示,如果你得到同一课程的所有成绩,可以节省你的时间:

  1. 您是否正常化数据,使所有值介于0和1之间 (或-1和1之间),线性或使用均值和 标准偏差?它似乎与您的代码不同。
  2. 您是否参数搜索C的良好值(或C中的C和gamma) RBF内核的情况)?进行交叉验证或保持不变 组?你的代码似乎没有。

答案 3 :(得分:0)

A)没有人知道你在引用。如果你不想让人们理解你所指的是什么,请给出链接。

B)你需要参加机器学习课程。 Coursera上有一个免费的。模型的输出取决于数据本身 - 并且受模型参数的影响很大。模型参数通过缩放实现,您通常需要搜索它们。您的代码不包含任何内容 - 并且您已明确表示您不熟悉机器学习。通过获得必要的背景知识,您将在几分钟内完成可以在几分钟内完成的任务。

C)有许多版本的LIBSVM for Java,你没有提供任何你正在使用的版本的指示。每个人的工作方式都有所不同。