I am using LibSVM with Weka in my Java code, and I want to do regression. Here is my code:
public static void predict() {
    try {
        DataSource sourcePref1 = new DataSource("train_pref2new.arff");
        Instances trainData = sourcePref1.getDataSet();
        DataSource sourcePref2 = new DataSource("testDatanew.arff");
        Instances testData = sourcePref2.getDataSet();
        if (trainData.classIndex() == -1) {
            trainData.setClassIndex(trainData.numAttributes() - 2);
        }
        if (testData.classIndex() == -1) {
            testData.setClassIndex(testData.numAttributes() - 2);
        }
        LibSVM svm1 = new LibSVM();
        String options = ("-S 3 -K 2 -D 3 -G 1000.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.001 -P 0.1");
        String[] optionsArray = options.split(" ");
        svm1.setOptions(optionsArray);
        svm1.buildClassifier(trainData);
        for (int i = 0; i < testData.numInstances(); i++) {
            double pref1 = svm1.classifyInstance(testData.instance(i));
            System.out.println("predicted value : " + pref1);
        }
    } catch (Exception ex) {
        Logger.getLogger(Test.class.getName()).log(Level.SEVERE, null, ex);
    }
}
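However, the predicted values I get from this code differ from those I get using the Weka GUI.
Example: below is a single test instance that I gave to both the Java code and the Weka GUI.
The Java code predicts 1.9064516129032265, whereas the Weka GUI predicts 10.043. I am using the same training data set and the same parameters for both the Java code and the Weka GUI.
I hope you understand my question. Can anyone tell me what is wrong with my code?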
Answer 0 (score: 2)
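You are using the wrong algorithm for SVM regression. LibSVM is mainly used for classification. What you want is SMOreg, Weka's SVM implementation specifically for regression.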
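Below is a complete example that shows how to use SMOreg with both the Weka Explorer GUI and the Java API. For data, I will use the cpu.arff data file that ships with the Weka distribution. Note that I use this file for both training and testing, but ideally you would have separate data sets.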
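As a rough illustration of the "separate data sets" point, here is a minimal plain-Java sketch of a reproducible 80/20 holdout split over row indices. It deliberately does not use the Weka API (Weka has its own facilities for splitting data); the class name, the seed, and the 80% ratio are assumptions made for illustration, and 209 is the instance count reported for cpu.arff in the program output of this answer.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper, not part of Weka: splits row indices into train and test parts.
final class HoldoutSplitSketch {

    /** Returns the row indices 0..n-1 in a shuffled order that is reproducible for a given seed. */
    static List<Integer> shuffledIndices(int n, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        return indices;
    }

    /** Number of rows assigned to the training part (integer arithmetic, rounds down). */
    static int trainSize(int n, int trainPercent) {
        return n * trainPercent / 100;
    }

    public static void main(String[] args) {
        int n = 209; // instance count of cpu.arff as reported in the output of this answer
        List<Integer> shuffled = shuffledIndices(n, 42L); // seed 42 is an arbitrary choice
        int cut = trainSize(n, 80);
        List<Integer> trainRows = shuffled.subList(0, cut);
        List<Integer> testRows = shuffled.subList(cut, n);
        System.out.printf("train rows: %d, test rows: %d%n", trainRows.size(), testRows.size());
    }
}
```

With 209 rows this yields 167 training rows and 42 test rows. The two index lists could then be used to write two separate ARFF files, or to copy the corresponding rows into two data sets.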
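Using the Weka Explorer GUI
1. In the Preprocess tab, click Open File and open the cpu.arff file that should be in your Weka distribution. On my system, the file is under weka-3-8-1/data/cpu.arff.
2. Go to the Classify tab. It should really be called "Prediction", because you can do both classification and regression here. Under Classifier, click Choose and select weka -> classifiers -> functions -> SMOreg.
3. Under Test Options, select Use training set so that our training set is also used for testing (as mentioned above, this is not the ideal methodology). Now press Start. Note the RMSE value in the results (74.5996); we will revisit it in the Java implementation.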
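For reference, the root mean squared error (RMSE) is conventionally defined as below. The symbols are introduced here for illustration and do not come from Weka's output: n is the number of test instances, y_i is the true value of instance i, and \hat{y}_i is the model's prediction for it.

```latex
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}
```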
Using the Java API
Below is a complete Java program that uses the Weka API to replicate the results shown earlier in the Weka Explorer GUI.
import weka.classifiers.functions.SMOreg;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class Tester {

    /**
     * Builds a regression model using SMOreg, the SVM for regression, and
     * evaluates it with the Evaluation framework.
     */
    public void buildAndEvaluate(String trainingArff, String testArff) throws Exception {
        System.out.printf("buildAndEvaluate() called.\n");
        // Load the training and test instances.
        Instances trainingInstances = DataSource.read(trainingArff);
        Instances testInstances = DataSource.read(testArff);
        // Set the true value to be the last field in each instance.
        trainingInstances.setClassIndex(trainingInstances.numAttributes() - 1);
        testInstances.setClassIndex(testInstances.numAttributes() - 1);
        // Build the SMOregression model.
        SMOreg smo = new SMOreg();
        smo.buildClassifier(trainingInstances);
        // Use Weka's evaluation framework.
        Evaluation eval = new Evaluation(trainingInstances);
        eval.evaluateModel(smo, testInstances);
        // Print the options that were used in the ML algorithm.
        String[] options = smo.getOptions();
        System.out.printf("Options used:\n");
        for (String option : options) {
            System.out.printf("%s ", option);
        }
        System.out.printf("\n\n");
        // Print the algorithm details.
        System.out.printf("Algorithm:\n %s\n", smo.toString());
        // Print the evaluation results.
        System.out.printf("%s\n", eval.toSummaryString("\nResults\n=====\n", false));
    }

    /**
     * Builds a regression model using SMOreg, the SVM for regression, and
     * tests each data instance individually to compute RMSE.
     */
    public void buildAndTestEachInstance(String trainingArff, String testArff) throws Exception {
        System.out.printf("buildAndTestEachInstance() called.\n");
        // Load the training and test instances.
        Instances trainingInstances = DataSource.read(trainingArff);
        Instances testInstances = DataSource.read(testArff);
        // Set the true value to be the last field in each instance.
        trainingInstances.setClassIndex(trainingInstances.numAttributes() - 1);
        testInstances.setClassIndex(testInstances.numAttributes() - 1);
        // Build the SMOregression model.
        SMOreg smo = new SMOreg();
        smo.buildClassifier(trainingInstances);
        int numTestInstances = testInstances.numInstances();
        // This variable accumulates the squared error from each test instance.
        double sumOfSquaredError = 0.0;
        // Loop over each test instance.
        for (int i = 0; i < numTestInstances; i++) {
            Instance instance = testInstances.instance(i);
            double trueValue = instance.value(testInstances.classIndex());
            double predictedValue = smo.classifyInstance(instance);
            // Uncomment the next line to see every prediction on the test instances.
            //System.out.printf("true=%10.5f, predicted=%10.5f\n", trueValue, predictedValue);
            double error = trueValue - predictedValue;
            sumOfSquaredError += (error * error);
        }
        // Print the RMSE results.
        double rmse = Math.sqrt(sumOfSquaredError / numTestInstances);
        System.out.printf("RMSE = %10.5f\n", rmse);
    }

    public static void main(String[] argv) throws Exception {
        Tester classify = new Tester();
        classify.buildAndEvaluate("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
        classify.buildAndTestEachInstance("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
    }
}
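I wrote two functions that train an SMOreg model and evaluate it by running predictions on the training data.
buildAndEvaluate() evaluates the model using Weka's Evaluation framework, which runs a suite of tests and gives exactly the same results as the Explorer GUI. Notably, it produces an RMSE value.
buildAndTestEachInstance() evaluates the model explicitly by looping over each test instance, making a prediction, computing the error, and computing the overall RMSE. Note that this RMSE matches the one from buildAndEvaluate(), which in turn matches the one from the Explorer GUI.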
Here are the results of compiling and running the program.
prompt> javac -cp weka.jar Tester.java
prompt> java -cp .:weka.jar Tester
buildAndEvaluate() called.
Options used:
-C 1.0 -N 0 -I weka.classifiers.functions.supportVector.RegSMOImproved -T 0.001 -V -P 1.0E-12 -L 0.001 -W 1 -K weka.classifiers.functions.supportVector.PolyKernel -E 1.0 -C 250007
Algorithm:
SMOreg
weights (not support vectors):
+ 0.01 * (normalized) MYCT
+ 0.4321 * (normalized) MMIN
+ 0.1847 * (normalized) MMAX
+ 0.1175 * (normalized) CACH
+ 0.0973 * (normalized) CHMIN
+ 0.0235 * (normalized) CHMAX
- 0.0168
Number of kernel evaluations: 21945 (93.081% cached)
Results
=====
Correlation coefficient 0.9044
Mean absolute error 31.7392
Root mean squared error 74.5996
Relative absolute error 33.0908 %
Root relative squared error 46.4953 %
Total Number of Instances 209
buildAndTestEachInstance() called.
RMSE = 74.59964
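To check the per-instance RMSE arithmetic by hand without Weka on the classpath, the same accumulation can be run on a few made-up numbers. This is only a sketch: the class name RmseSketch and the toy values are hypothetical, and the loop mirrors the one in buildAndTestEachInstance() rather than any Weka internals.

```java
// Hypothetical stand-alone helper; it does not depend on Weka.
final class RmseSketch {

    /** Root mean squared error between paired true and predicted values. */
    static double rmse(double[] trueValues, double[] predictedValues) {
        if (trueValues.length != predictedValues.length || trueValues.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        // Accumulate the squared error of each pair.
        double sumOfSquaredError = 0.0;
        for (int i = 0; i < trueValues.length; i++) {
            double error = trueValues[i] - predictedValues[i];
            sumOfSquaredError += error * error;
        }
        return Math.sqrt(sumOfSquaredError / trueValues.length);
    }

    public static void main(String[] args) {
        double[] trueValues = {1.0, 2.0, 3.0};      // made-up toy data
        double[] predictedValues = {1.0, 2.0, 5.0}; // made-up toy predictions
        // Errors are 0, 0 and -2, so the sum of squares is 4 and RMSE = sqrt(4 / 3).
        System.out.printf("RMSE = %10.5f%n", rmse(trueValues, predictedValues));
    }
}
```

For the toy data the result is sqrt(4/3), roughly 1.1547, which is easy to confirm on paper before trusting the same loop on real predictions.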