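I am trying to classify a dataset. In this dataset, the first column is the ideal result and the other 20 columns are the inputs.
The problem I am running into is that an SVM trained on the dataset (80% of it is used for training in this case) reports a training error of 0.0, yet always predicts 1.0 as the result.
I split the set into two parts: one for training (80% of the data) and the remaining 20% for classification. The data is a concatenation of the RSI values of two short time series (one with a period of 2 and one with a period of 14).
Why does the SVM behave this way, and is there something I can do to avoid it? I assumed that a training error of 0.0 means the SVM no longer makes any mistakes on the training set. Judging from the results, that seems to be wrong.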
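One way to test the assumption about the training error is to count the mistakes on the training set directly, instead of relying only on the error reported by the trainer. The following is a minimal sketch under stated assumptions: `ErrorCheck` and `misclassificationRate` are hypothetical names (not part of Encog), and the arrays in `main` are made-up labels. In the program from this question, `predicted` would be filled from `svm.compute(pair.getInput()).getData(0)` for every pair in the training set.

```java
// Hypothetical helper, not part of Encog: counts mistakes on a labelled set.
final class ErrorCheck {

    /** Returns the fraction of samples whose predicted label differs from the ideal label. */
    static double misclassificationRate(double[] predicted, double[] ideal) {
        if (predicted.length != ideal.length) {
            throw new IllegalArgumentException("predicted and ideal must have the same length");
        }
        if (predicted.length == 0) {
            return 0.0;
        }
        int wrong = 0;
        for (int i = 0; i < predicted.length; i++) {
            // Labels are class indices stored as doubles, so compare them after rounding.
            if (Math.round(predicted[i]) != Math.round(ideal[i])) {
                wrong++;
            }
        }
        return (double) wrong / predicted.length;
    }

    public static void main(String[] args) {
        // Made-up labels for illustration: a model that always answers 1.0.
        double[] predicted = {1.0, 1.0, 1.0, 1.0};
        double[] ideal = {1.0, 0.0, 1.0, 0.0};
        System.out.println("misclassification rate: " + misclassificationRate(predicted, ideal));
    }
}
```

If this rate is clearly above zero on the training set while the trainer still reports 0.0, the two numbers are probably not measuring the same thing.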
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.encog.Encog;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.svm.training.SVMTrain;
public class SVMTest {

    public static void main(String[] args) {
        List<String> lines = readFile("/home/wens/mlDataSet.csv");

        // Index of the first held-out row. Integer division, so roughly 80% of the rows precede it.
        int splitIndex = lines.size() / 10 * 8;

        // First part of the file: training set.
        double[][] trainingSetData = getInputData(lines, 0, splitIndex);
        double[][] trainingIdeal = getIdeal(lines, 0, splitIndex);
        MLDataSet trainingSet = new BasicMLDataSet(trainingSetData, trainingIdeal);

        // Remaining rows: used only for classification.
        double[][] classificationSetData = getInputData(lines, splitIndex, lines.size());
        double[][] classificationIdeal = getIdeal(lines, splitIndex, lines.size());
        MLDataSet classificationSet = new BasicMLDataSet(classificationSetData, classificationIdeal);

        // 20 inputs; the second argument (regression) is false, i.e. a classification SVM.
        SVM svm = new SVM(20, false);
        final SVMTrain train = new SVMTrain(svm, trainingSet);
        train.iteration();
        train.finishTraining();

        System.out.println("training error: " + train.getError());
        System.out.println("SVM Results:");
        for (MLDataPair pair : classificationSet) {
            final MLData output = svm.compute(pair.getInput());
            System.out.println("actual: " + output.getData(0) + "\tideal=" + pair.getIdeal().getData(0));
        }

        Encog.getInstance().shutdown();
    }

    // Reads the whole file into memory, one list entry per line.
    private static List<String> readFile(String filepath) {
        List<String> res = new ArrayList<>();
        File f = new File(filepath);
        // try-with-resources closes the reader even if reading fails.
        try (BufferedReader b = new BufferedReader(new FileReader(f))) {
            String readLine;
            while ((readLine = b.readLine()) != null) {
                res.add(readLine);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return res;
    }

    // Parses columns 1..20 (the inputs) of the tab-separated rows in [start, end).
    private static double[][] getInputData(List<String> lines, int start, int end) {
        double[][] res = new double[end - start][20];
        int cnt = 0;
        for (int i = start; i < end; i++) {
            String[] tmp = lines.get(i).split("\t");
            for (int j = 1; j < tmp.length; j++) {
                res[cnt][j - 1] = Double.parseDouble(tmp[j]);
            }
            cnt++;
        }
        return res;
    }

    // Parses column 0 (the ideal result) of the tab-separated rows in [start, end).
    private static double[][] getIdeal(List<String> lines, int start, int end) {
        double[][] res = new double[end - start][1];
        int cnt = 0;
        for (int i = start; i < end; i++) {
            String[] tmp = lines.get(i).split("\t");
            res[cnt][0] = Double.parseDouble(tmp[0]);
            cnt++;
        }
        return res;
    }
}
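
Because the classifier always answers 1.0, it may also help to see how the ideal values are distributed in each part of the split. The sketch below counts the labels in the first column for a range of rows, using the same tab-separated layout and the same `lines.size()/10*8` split as the program above. `LabelStats` and `countLabels` are hypothetical names, and the rows in `main` are made-up sample data rather than rows from the real file. This does not explain the behaviour by itself; it only shows whether one label dominates the training rows or the held-out rows.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper, not part of Encog: label counts for a range of rows.
final class LabelStats {

    /** Counts how often each value of the first (ideal) column occurs in rows [start, end). */
    static Map<Double, Integer> countLabels(List<String> lines, int start, int end) {
        Map<Double, Integer> counts = new TreeMap<>();
        for (int i = start; i < end; i++) {
            // Same layout as in the question: tab-separated, ideal value in column 0.
            double label = Double.parseDouble(lines.get(i).split("\t")[0]);
            counts.merge(label, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // Made-up rows for illustration (label, then two of the inputs).
        List<String> lines = Arrays.asList(
                "1.0\t55.2\t48.1", "1.0\t60.3\t51.0", "0.0\t20.9\t45.7", "1.0\t70.4\t52.3",
                "1.0\t65.0\t53.8", "0.0\t30.1\t47.2", "1.0\t58.6\t50.5", "1.0\t61.7\t51.9",
                "0.0\t25.3\t46.0", "1.0\t66.8\t54.4");

        // Same integer split as in the program above.
        int splitIndex = lines.size() / 10 * 8;

        System.out.println("training labels: " + countLabels(lines, 0, splitIndex));
        System.out.println("held-out labels: " + countLabels(lines, splitIndex, lines.size()));
    }
}
```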