堆叠来自不同包的模型

时间:2017-11-01 17:13:40

标签: r merge classification svm r-caret

我正在使用e1071软件包中的SVM进行破产预测(分类)。为了改善我的结果,我想将它与插入包中的随机森林结合起来。首先,我将展示我的RF模型,然后我将展示SVM模型。之后,我将展示我尝试组合(堆叠)它们。

提前抱歉凌乱的代码。我是这一切的新手。

RF模型(插入符号包)

set.seed(123)
model.rf <- train(as.factor(year.of.bankruptcy) ~ ., method = "rf", data = training.set)
predict.rf <- predict(model.rf, testing.set[,-1])

RF模型精度

confusionMatrix(predict.rf, testing.set$year.of.bankruptcy, mode="everything")$overall[1]

- &GT;这给了我模型的准确性: 准确性 0.7166667

SVM(e1071包)

set.seed(123)
model1<-function(k,d,c,g){
  model <-svm(year.of.bankruptcy ~., data = training.set, type = "C-classification", kernel = k, degree= d, cost =c, gamma =g)
  1<-testing.set[,-1]
  2<-testing.set$year.of.bankruptcy
  model_prediction <- predict(model, 1)
  result<-table(model_prediction, 2)
  return(result)
}

result<-model1(k="radial", d=2, c=2,g=0.1)
result
classAgreement(tab=result, match.names = FALSE)
classAgreement(tab=result, match.names = FALSE)$diag

- &GT;这给了我模型的准确性: [1] 0.7466667

将模型堆叠在一起

predictDF <- data.frame(predict.rf, classAgreement(tab=result, match.names = FALSE)$diag, class = testing.set$year.of.bankruptcy)
predictDF_bc <- ROSE(class ~.,predictDF, N=300, p=0.5, seed=12)$data

set.seed(123)

combined.model.gbm <- train(as.factor(class) ~ ., method = "gbm", data = predictDF_bc, distribution = "bernoulli")
combined.prediction.gbm <- predict(combined.model.gbm, predictDF)

评估模型

confusionMatrix(combined.prediction.gbm, testing.set$year.of.bankruptcy)$overall[1]`enter code here`

- &GT;这给了我堆叠模型的准确性: 准确性 0.7166667

如您所见,组合模型不考虑SVM。由于我的综合得分低于我的SVM得分。对我能做什么的任何建议?

  > dput(training.set[sample(1:nrow(training.set), 50),])

structure(list(year.of.bankruptcy = c(-1, -1, -1, -1, -1, -1, -1, 1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, 1, -1, -1, -1, -1, -1), liquidity_1 = c(90.0695931477516, 
85.4305617311398, 76.2455934195065, 4.34688111280157, 159.020111900801, 
104.569486404834, 58.3391003460208, 42.0907973873116, 101.121495327103, 
94.3786295005807, 47.7552816901408, 125.702184574231, 125.763725699637, 
106.584557081952, 0, 143.6203466894, 82.5245328673209, 35.296442687747, 
8.85744561490993, 12.4657534246575, 128.164489183979, 133.131146034372, 
92.0528568769775, 22.8177150192555, 100.237812128419, 40.0340715502555, 
91.360486091332, 129.123757904246, 92.9165443694355, 130.999694283094, 
22.2526106414719, 101.714770797963, 93.1704260651629, 46.6268560361524, 
125.838858750251, 106.076759061834, 86.787017476474, 84.7495991700462, 
42.1171171171171, 68.806311160926, 93.1549687282835, 104.196667352397, 
47.0834921845215, 77.8816199376947, 76.9065981148243, 90.988709507228, 
98.9704873026767, 163.446031970576, 113.768115942029, 92.9742188833874
), profmarg_1 = c(241.916488222698, 215.221579961464, 633.490011750881, 
0, 173.627703009224, 193.164652567976, 3.32179930795848, 82.390221819828, 
131.842456608812, 102.044134727062, 0, 7.2447614801605, 113.608203375347, 
169.208905731881, 0, 179.866439329355, 250.396558677242, 48.0632411067194, 
0, 12.8082191780822, 0.963803812379525, 0, 452.279918109064, 
0, 16.4090368608799, 11.4449434722007, 173.331434539068, 240.216802168022, 
307.709617454261, 179.883827575665, 281.476877175535, 539.609507640068, 
183.12447786132, 31.8431245965139, 151.215591721921, 95.3980099502487, 
259.97695410025, 174.073375459776, 11.986986986987, 160.94322541708, 
119.110493398193, 428.03949804567, 194.624475791079, 325.877466251298, 
37.2322193658955, 245.71066793289, 207.343857240906, 22.49257320696, 
43.6487638533674, 97.4987194809629), drmarg = c(1.46603230803275, 
12.6575304731079, -0.798553144129104, 53.3333333333333, 11.8097892353249, 
29.1893259137473, 60.4166666666667, -23.041601255887, 1.21518987341772, 
6.1535019019915, 82.4626865671642, -4, 4.47536667920271, -3.69540873460246, 
65.3543307086614, 6.46738701790362, -3.63987759703656, 0.575657894736842, 
70.2460850111857, 45.4545454545455, -724.444444444444, 18.809947734191, 
3.22818215293973, 92.9292929292929, 6.52173913043478, 50.8680555555556, 
4.88031987730733, 19.9684115523466, 1.1446376903755, 13.3729821580289, 
1.22027317479027, 4.0986955838441, -3.29607664233577, 73.4414597060314, 
3.95960669678448, 28.6645874681032, 17.2991867598802, 10.8455534851063, 
55.741127348643, 8.98526582981339, 7.36196319018405, 4.85894170231172, 
10.4852855193919, -1.6774275224712, 16.3210702341137, 2.47726693294808, 
5.64784053156146, 59.622641509434, 11.0029211295034, 50.5987773218323
), ROA = c(3.546573875803, 27.2417370683267, -5.05875440658049, 
6.52032166920235, 20.5050657795252, 87.1601208459215, 2.00692041522491, 
-18.9840263855655, 1.60213618157543, 6.38792102206736, 9.72711267605634, 
-0.356665180561748, 5.08438367870113, -6.25296068214116, 3.53041259038707, 
11.6510372264848, -9.11412824304342, 0.276679841897233, 5.87171975316337, 
5.82191780821918, -6.98222317412722, 30.0983365499495, 14.6845337800112, 
11.8100128369705, 1.07015457788347, 6.05028134840741, 8.45912845343207, 
47.9674796747967, 3.52216025829175, 24.0599205136044, 4.37593237195425, 
22.1392190152801, -6.0359231411863, 23.3860555196901, 5.98754269640346, 
35.9275053304904, 46.5719224121375, 18.9380364047911, 6.68168168168168, 
19.5326981937319, 9.17303683113273, 20.7981896729068, 20.5108654212734, 
-5.50363447559709, 10.4541559554413, 6.15173578136541, 12.4456646076413, 
13.4106662894327, 4.81670929241262, 51.5793068123613), debt_ratio_1 = c(75.6423982869379, 
157.077219504965, 180.975323149236, 88.958921973484, 96.869801905338, 
93.0513595166163, 78.6159169550173, 131.707948004915, 132.096128170895, 
100.789779326365, 28.080985915493, 48.1497993758359, 85.6868190557573, 
85.5518711511132, 75.4714305969091, 92.0431940892299, 123.551552628041, 
43.8735177865613, 89.2601134451162, 69.0547945205479, 29.727993146284, 
110.265600588181, 154.662199888331, 54.2362002567394, 20.9274673008323, 
79.0666460172423, 150.536409380044, 101.355013550135, 145.827218471774, 
45.2155304188322, 123.222277473894, 134.90662139219, 123.141186299081, 
41.7043253712072, 66.2648181635523, 26.5813788201848, 95.1411561359708, 
105.191926813166, 7.60760760760761, 179.997413458637, 92.7032661570535, 
121.49763423164, 96.3400686237133, 129.823468328141, 39.502999143102, 
136.213991769547, 119.01166781057, 84.8210496534163, 8.99403239556692, 
113.957657503842), young = c(1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1), medium_age = c(0, 
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 
0, 1, 0, 1, 0, 0, 0), old = c(0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 
1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 
1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0), agriculture = c(0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0), offshore_shipping = c(0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0), transport = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), manufacturing = c(0, 
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0), telecom_it_tech = c(0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0), electricity = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), construction = c(0, 
0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 
1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 1), wholesale_retail = c(0, 0, 1, 0, 1, 0, 
0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 
0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 
0, 0), finance = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), change_output = c(0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549), oil_price_dummy = c(0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0), fish_price_dummy = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0.180737819481274, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.180737819481274, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0)), .Names = c("year.of.bankruptcy", "liquidity_1", 
"profmarg_1", "drmarg", "ROA", "debt_ratio_1", "young", "medium_age", 
"old", "agriculture", "offshore_shipping", "transport", "manufacturing", 
"telecom_it_tech", "electricity", "construction", "wholesale_retail", 
"finance", "change_output", "oil_price_dummy", "fish_price_dummy"
), row.names = c(19L, 49L, 25L, 53L, 56L, 3L, 31L, 50L, 58L, 
62L, 51L, 24L, 35L, 29L, 6L, 44L, 12L, 2L, 15L, 42L, 39L, 30L, 
27L, 40L, 26L, 41L, 21L, 22L, 11L, 63L, 32L, 60L, 36L, 52L, 1L, 
14L, 37L, 34L, 8L, 43L, 4L, 10L, 9L, 54L, 59L, 64L, 23L, 20L, 
17L, 13L), class = "data.frame")

1 个答案:

答案 0 :(得分:1)

使用caretEnsemble库可以非常轻松地完成堆叠模型 这是一个例子:

library(mlbench) #for the data set
library(caret)
library(caretEnsemble)

data(PimaIndiansDiabetes)
set.seed(123)

列出要使用的算法:

algorithmList <- c("svmRadial", "rf" ) 

如果您想在每个模型中指定调整参数,请在tuneList函数中使用caretList参数:

trainControl中的

savePredictions = "final"classProbs = TRUE是强制性的

control <- trainControl(method = "repeatedcv", number = 4, repeats = 3, 
                        savePredictions = "final" , classProbs = TRUE)

models <- caretList(diabetes ~ ., data = PimaIndiansDiabetes, trControl = control,
                     metric = "Kappa", methodList = algorithmList)

results <- resamples(models)

summary(results)
#output
Call:
summary.resamples(object = results)

Models: svmRadial, rf 
Number of resamples: 12 

Accuracy 
               Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
svmRadial 0.6979167 0.7135417 0.7343750 0.7304688 0.7447917 0.7604167    0
rf        0.7291667 0.7604167 0.7682292 0.7690972 0.7760417 0.8125000    0

Kappa 
               Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
svmRadial 0.2637842 0.3570103 0.4053130 0.3917770 0.4394767 0.4775359    0
rf        0.3788379 0.4612661 0.4788076 0.4809233 0.5028566 0.5785880    0

现在是堆栈,

stack.glm <- caretStack(models, method = "glm", metric = "Kappa", trControl = control)
print(stack.glm)
#output
A glm ensemble of 2 base models: svmRadial, rf

Ensemble results:
Generalized Linear Model 

2304 samples
   2 predictor
   2 classes: 'neg', 'pos' 

No pre-processing
Resampling: Cross-Validated (4 fold, repeated 3 times) 
Summary of sample sizes: 1728, 1728, 1728, 1728, 1728, 1728, ... 
Resampling results:

  Accuracy   Kappa    
  0.7667824  0.4685406

或gbm堆栈

stack.gbm <- caretStack(models, method="gbm", metric = "Kappa", trControl = control)

print(stack.gbm)
#output
A gbm ensemble of 2 base models: svmRadial, rf

Ensemble results:
Stochastic Gradient Boosting 

2304 samples
   2 predictor
   2 classes: 'neg', 'pos' 

No pre-processing
Resampling: Cross-Validated (4 fold, repeated 3 times) 
Summary of sample sizes: 1728, 1728, 1728, 1728, 1728, 1728, ... 
Resampling results across tuning parameters:

  interaction.depth  n.trees  Accuracy   Kappa    
  1                   50      0.7693866  0.4832061
  1                  100      0.7675058  0.4785977
  1                  150      0.7663484  0.4753614
  2                   50      0.7662037  0.4748160
  2                  100      0.7638889  0.4684015
  2                  150      0.7634549  0.4653090
  3                   50      0.7630208  0.4657834
  3                  100      0.7612847  0.4606506
  3                  150      0.7569444  0.4511977

Tuning parameter 'shrinkage' was held constant at a value of 0.1
Tuning parameter 'n.minobsinnode' was
 held constant at a value of 10
Kappa was used to select the optimal model using  the largest value.
The final values used for the model were n.trees = 50, interaction.depth = 1, shrinkage = 0.1 and n.minobsinnode
 = 10.

所以k的值为 svm:0.3917770
rf:0.4809233
glm合奏:0.4685406
gbm合奏:0.4832061 - 如果使用更多模型,这可能会更高

编辑:使用OP提供的数据:

首先将year.of.bankruptcy转换为因子

data$year.of.bankruptcy <- as.factor(data$year.of.bankruptcy)

将级别名称设置为不会引发错误的内容:

levels(data$year.of.bankruptcy) <- c("minus", "plus")

继续前进

control <- trainControl(method = "repeatedcv", number = 4, repeats = 3, 
                        savePredictions = "final" , classProbs = TRUE)

models <- caretList(year.of.bankruptcy ~ ., data = data, trControl = control,
                    metric = "Kappa", methodList = algorithmList)

我收到关于零方差预测因子的警告,但这可能是由小数据样本引起的。如果您看到如下错误:

In .local(x, ...) : Variable(s) `' constant. Cannot scale data.

在整个数据集上然后值得研究去除近零方差预测值。关于这个here有一个很好的章节。祝你好运