尽管有0.0训练错误,但经过训练的SVM只能输出1.0结果

时间:2019-02-17 20:12:21

标签: java svm encog

我正在尝试对dataset进行分类!在此数据集中,第一列是理想结果,其他20列是输入。

在我这里出现的问题是,在数据集上训练的SVM(在这种情况下,80%用于训练)显示出0.0的训练误差,但始终将1.0预测为结果。

我将集合分为两部分,一个用于训练(数据的80%),另一个用于分类的20%。该数据是两个短时间序列的RSI值(一个2个周期和一个14个周期)的串联。

为什么SVM会有这种行为?我可以采取一些措施来避免这种情况吗?我认为训练误差为0.0意味着,在训练集上SVM不会再犯错误。从结果来看,这似乎是错误的。

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.encog.Encog;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.svm.training.SVMTrain;

public class SVMTest {

    public static void main(String[] args) {
        List<String> lines = readFile("/home/wens/mlDataSet.csv");
        double[][] trainingSetData = getInputData(lines, 0, lines.size()/10*8);
        double[][] trainingIdeal = getIdeal(lines, 0, lines.size()/10*8);
        MLDataSet trainingSet = new BasicMLDataSet(trainingSetData, trainingIdeal);
        double[][] classificationSetData = getInputData(lines, lines.size()/10*8, lines.size());
        double[][] classificationIdeal = getIdeal(lines, lines.size()/10*8, lines.size());
        MLDataSet classificationSet = new BasicMLDataSet(classificationSetData, classificationIdeal);

        SVM svm = new SVM(20,false);
        final SVMTrain train = new SVMTrain(svm, trainingSet);
        train.iteration();
        train.finishTraining();
        System.out.println("training error: " + train.getError());

        System.out.println("SVM Results:");
        for(MLDataPair pair: classificationSet ) {
            final MLData output = svm.compute(pair.getInput());
            System.out.println("actual: " + output.getData(0) + "\tideal=" + pair.getIdeal().getData(0));
        }

        Encog.getInstance().shutdown();
    }

    private static List<String> readFile(String filepath){
        List<String> res = new ArrayList<>();
        try {
            File f = new File(filepath);
            BufferedReader b = new BufferedReader(new FileReader(f));
            String readLine = "";
            while ((readLine = b.readLine()) != null) {
                res.add(readLine);
            }

        } catch (IOException e) {
            e.printStackTrace();
        }
        return res;
    }

    private static double[][] getInputData(List<String> lines, int start, int end){
        double[][] res = new double[end-start][20];
        int cnt = 0;
        for(int i=start; i<end; i++){
            String[] tmp = lines.get(i).split("\t");
            for(int j=1; j<tmp.length; j++){
                res[cnt][j-1] = Double.parseDouble(tmp[j]);
            }
            cnt++;
        }
        return res;
    }

    private static double[][] getIdeal(List<String> lines, int start, int end){
        double[][] res = new double[end-start][1];
        int cnt = 0;
        for(int i=start; i<end; i++){
            String[] tmp = lines.get(i).split("\t");
            res[cnt][0] = Double.parseDouble(tmp[0]);
            cnt++;
        }
        return res;
    }
}

0 个答案:

没有答案