Classification using SVM

Date: 2017-08-08 18:49:03

Tags: r machine-learning r-caret

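I'm new to machine learning and to R in general, and I'm trying to use caret's built-in functions to train a classification SVM that predicts whether a stock price will go up or down. My code is below: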

library(caret)
library(kernlab)
library(data.table)  # setattr() and the with = FALSE syntax below come from data.table

# Load data here (DataSet is a data.table; its last column, AAPL.res, is the up/down outcome)

sample_size = floor(0.85*nrow(DataSet))

trainSet = DataSet[1:sample_size,1:(ncol(DataSet)-1)]
testSet = DataSet[(sample_size+1):nrow(DataSet),1:(ncol(DataSet)-1)]

survived = DataSet[1:sample_size,ncol(DataSet),with = FALSE]
survivedtest = DataSet[(sample_size+1):nrow(DataSet),ncol(DataSet),with = FALSE]
survived = as.factor(survived$AAPL.res)
survivedtest = as.factor(survivedtest$AAPL.res)
setattr(survived,"levels",c("UP", "DOWN"))
setattr(survivedtest,"levels",c("UP", "DOWN"))

ctrl = trainControl(method = "LGOCV",
                     summaryFunction = twoClassSummary,
                     classProbs = TRUE,
                     index = list(TrainSet = 1:nrow(trainSet)),
                     savePredictions = TRUE,
                     verboseIter =TRUE)

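set.seed(201)
# sigest() estimates a plausible range for the RBF kernel width sigma (10%, 50% and 90% quantiles);
# the grid fixes sigma at the first of these and tunes the cost C from 2^-3 to 2^4
sigmaRangeFull = sigest(as.matrix(trainSet))
svmRGridFull = expand.grid(sigma = as.vector(sigmaRangeFull)[1], C = 2^(-3:4))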

set.seed(476)
svmRFit = train(x = trainSet, 
                 y = survived,
                 method = "svmRadial",
                 metric = "ROC",
                 preProc = c("center", "scale"),
                 tuneGrid = svmRGridFull,
                 trControl = ctrl)
svmRFit

When I try to train the model, I get an error:

> svmRFit = train(x = trainSet, 
+                  y = survived,
+                  method = "svmRadial",
+                  metric = "ROC",
+                  preProc = c("center", "scale"),
+                  tuneGrid = svmRGridFull,
+                  trControl = ctrl)
+ TrainSet: sigma=0.001371, C= 0.125 
+ TrainSet: sigma=0.001371, C= 0.250 
+ TrainSet: sigma=0.001371, C= 0.500 
+ TrainSet: sigma=0.001371, C= 1.000 
+ TrainSet: sigma=0.001371, C= 2.000 
+ TrainSet: sigma=0.001371, C= 4.000 
+ TrainSet: sigma=0.001371, C= 8.000 
+ TrainSet: sigma=0.001371, C=16.000 
Error in { : task 1 failed - "replacement has 1 row, data has 0"
In addition: There were 48 warnings (use warnings() to see them)
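A likely cause, not confirmed anywhere in this thread: caret's `trainControl` documentation says that when `indexOut` is `NULL`, the held-out rows for each resample are the ones not listed in `index`. Because `index = list(TrainSet = 1:nrow(trainSet))` lists every training row, the hold-out set appears to be empty, which would be consistent with an error about data that has 0 rows. The base-R sketch that follows shows the empty set and one way to build a non-empty chronological hold-out; `n`, the 80/20 proportion and the list name `Fit` are hypothetical. caret also offers `method = "timeslice"` in `trainControl` (with `initialWindow`, `horizon` and `fixedWindow`) for rolling-origin resampling of time series.

```r
# Sketch only: n stands in for nrow(trainSet); the value 100 is hypothetical
n <- 100

# The question's resampling index lists every training row
index <- list(TrainSet = seq_len(n))

# With indexOut = NULL, caret holds out the rows not in index,
# and here that set is empty
held_out <- setdiff(seq_len(n), index$TrainSet)
length(held_out)  # 0

# One chronological alternative (an assumption, not the asker's design):
# fit on the first 80% of rows and evaluate on the remaining 20%
n_fit <- floor(0.8 * n)
index <- list(Fit = seq_len(n_fit))
indexOut <- list(Fit = (n_fit + 1):n)

# These lists would then be passed as
# trainControl(method = "LGOCV", index = index, indexOut = indexOut,
#              summaryFunction = twoClassSummary, classProbs = TRUE)
```

Keeping the fit rows strictly before the hold-out rows avoids evaluating on days that precede the training data, which matters for price series.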

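I'm not sure what exactly is causing this, since the data consist of numeric values and the y outcome is a factor. I have also tried other trainControl methods and ran into the same problem. The data used can be found here: Data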
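The assumptions in the paragraph above (numeric predictors, a factor outcome) can be checked directly before calling `train()`. `check_svm_inputs` is a hypothetical helper, not part of caret. Besides those two points it also looks for missing values, constant columns (their standard deviation is 0, so `preProc = c("center", "scale")` cannot scale them) and class levels that are not valid R names, which caret rejects when `classProbs = TRUE`.

```r
# Hypothetical helper (not part of caret): returns a character vector
# describing problems found in x and y; character(0) means none found
check_svm_inputs <- function(x, y) {
  x <- as.data.frame(x)  # also accepts a data.table
  problems <- character(0)

  if (nrow(x) != length(y))
    problems <- c(problems, "x and y have different lengths")

  is_num <- vapply(x, is.numeric, logical(1))
  if (!all(is_num))
    problems <- c(problems, paste("non-numeric columns:",
                                  paste(names(x)[!is_num], collapse = ", ")))

  if (anyNA(x) || anyNA(y))
    problems <- c(problems, "missing values in x or y")

  # A column with a single distinct value has standard deviation 0
  n_distinct <- vapply(x, function(col) length(unique(col[!is.na(col)])),
                       integer(1))
  if (any(n_distinct < 2))
    problems <- c(problems, paste("constant columns:",
                                  paste(names(x)[n_distinct < 2], collapse = ", ")))

  if (!is.factor(y) || nlevels(y) != 2) {
    problems <- c(problems, "y is not a two-level factor")
  } else if (!identical(make.names(levels(y)), levels(y))) {
    # classProbs = TRUE needs levels usable as column names, e.g. "UP", not "0"
    problems <- c(problems, "class levels are not valid R names")
  }

  problems
}
```

With the question's objects this would be called as `check_svm_inputs(trainSet, survived)`.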

P.S. Any comments on how to improve my code or coding practices are very welcome, as I'm trying to learn.
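On the P.S., one point worth checking is the labelling step: `as.factor` sorts the raw codes, and `setattr(survived, "levels", c("UP", "DOWN"))` renames the levels by position, so if a rise is coded with the larger value every rise would end up labelled "DOWN". The order also matters because caret's `twoClassSummary` reports sensitivity relative to the first level. `prepare_split` is a hypothetical helper that maps codes to labels explicitly and does the chronological split in one place; the argument `up_code` (the raw value that means "up") is an assumption, since the coding of `AAPL.res` is not shown.

```r
# Hypothetical helper (not from the question or from caret).
# Assumes rows are in time order and `outcome` names the label column.
prepare_split <- function(data, outcome, up_code, train_frac = 0.85) {
  data <- as.data.frame(data)  # also accepts a data.table

  # Map each raw code to a label explicitly instead of renaming levels by position
  y <- factor(ifelse(data[[outcome]] == up_code, "UP", "DOWN"),
              levels = c("UP", "DOWN"))

  predictors <- setdiff(names(data), outcome)
  n_train <- floor(train_frac * nrow(data))
  train_rows <- seq_len(n_train)
  test_rows <- setdiff(seq_len(nrow(data)), train_rows)

  list(
    x_train = data[train_rows, predictors, drop = FALSE],
    x_test  = data[test_rows, predictors, drop = FALSE],
    y_train = y[train_rows],
    y_test  = y[test_rows]
  )
}
```

With the question's data this would be something like `parts <- prepare_split(DataSet, "AAPL.res", up_code = 1)`, where `up_code = 1` is only a guess, followed by `train(x = parts$x_train, y = parts$y_train, ...)`.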

0 answers