我正在尝试训练SVM来对两个螺旋数据进行分类。
我的输入是3列CSV文件,前两列是螺旋上的点(未标准化)的(x,y)坐标,第三列是该点所属的螺旋(类)
我首先规范化CSV文件,使前两列介于0和1之间(第三列保持不变)。
然后我按如下方式创建和训练SVM
CSVNeuralDataSet trainingSet = new CSVNeuralDataSer(normaliseCSV("/path/to/data/file"), 2, 1, false);
SVM svm = new SVM(2, false);
final SVMSearchTrain train = new SVMSearchTrain(svm, trainingSet);
int epoch = 0;
do {
train.iteration();
System.out.println("Epoch $: " + epoch + " Error: " + train.getError());
epoch++;
} while(train.getError() > 0.01);
train.finishTraining();
然而,do ... while循环最终是一个无限循环,因为训练误差大约为0.4并且它永远不会改变。
数据集包含大约200个样本,并且只有两个类(0和1)。
有谁能告诉我为什么会失败?
编辑:Here is a pastebin link大约10%的培训数据。
答案 0 :(得分:2)
精彩的问题。您的问题是SVM无法构建螺旋数据的分离曲线。我建议你尝试规范化技巧,但不是根据X,Y坐标作为直线进行标准化,而是switch到polar system of coordinates。并考虑阿基米德螺旋,对数螺旋等。请看图片。螺旋数据要求SVM构建一些功能,将1类和2类之间的数据分开,我非常确定SVM不是一件容易的事。但是如果你能找到从螺旋数据表示切换到线性的方法,那么SVM将需要在两条曲线之间建立分离,这更容易。