在 mlr 中使用 makeClassifTask:如何关联已从任务中排除的 ID 列

时间:2019-05-14 17:15:57

标签: r machine-learning gbm mlr

我的数据中有一个 ID 列。由于它不是特征,我将此列从 trainTask 中删除了。但是,我想将预测概率与数据中实际的 ID 号关联起来。

我要匹配的列是Init_Acct,它是data.frame中的ID号

我的代码如下:

# Make classif tasks ----
# Init_Acct is an identifier, not a feature, so it is dropped from both tasks.
# mlr's predictions carry an `id` column that indexes the rows of the task
# data, which is what later lets predictions be matched back to Init_Acct.
trainTask <- makeClassifTask(
  data = train.df %>% dplyr::select(-Init_Acct) # Init_Acct is the ID I want to match
  , target = "READMIT_FLAG"
  , positive = "Y"
)
testTask <- makeClassifTask(
  data = test.df %>% dplyr::select(-Init_Acct)
  , target = "READMIT_FLAG"
  , positive = "Y"
)

# Rebalance the classes with SMOTE on the TRAINING task only.
# Oversampling the test task as well would score the model on synthetic
# observations and add rows that have no Init_Acct, so test predictions could
# no longer be mapped one-to-one onto test.df.
trainTask <- smote(trainTask, rate = 6)

# GBM ----
# List the gbm learner's tunable parameters, then build a learner that outputs
# class probabilities (predict.type = "prob") rather than only hard labels.
getParamSet("classif.gbm")
gbm.learner <- makeLearner(
  "classif.gbm"
  , predict.type = "prob"
)
plotLearnerPrediction(gbm.learner, trainTask)

# Tuning control: random search over 50 sampled configurations
gbm.tune.ctl <- makeTuneControlRandom(maxit = 50L)

# 3-fold cross-validation used to score each configuration
gbm.cv <- makeResampleDesc("CV", iters = 3L)

# Hyper-parameter space (sampled by the random search above, not a full grid)
gbm.par <- makeParamSet(
  makeDiscreteParam("distribution", values = "bernoulli")
  , makeIntegerParam("n.trees", lower = 10L, upper = 1000L)
  , makeIntegerParam("interaction.depth", lower = 2L, upper = 10L)
  , makeIntegerParam("n.minobsinnode", lower = 10L, upper = 80L)
  , makeNumericParam("shrinkage", lower = 0.01, upper = 1)
)

# Tune hyper-parameters on 4 socket workers.
# NOTE(review): trainTask is SMOTE-oversampled before resampling, so synthetic
# rows built from one original case can fall into different CV folds and
# inflate the CV accuracy. mlr's makeSMOTEWrapper() applies SMOTE inside each
# training fold instead; consider switching to it.
parallelMap::parallelStartSocket(
  4
  , level = "mlr.tuneParams"
)
# `finally` stops the socket cluster whether tuning succeeds or errors, so a
# failed run does not leave the workers running.
gbm.tune <- tryCatch(
  tuneParams(
    learner = gbm.learner
    , task = trainTask
    , resampling = gbm.cv
    , measures = acc
    , par.set = gbm.par
    , control = gbm.tune.ctl
  )
  , finally = parallelMap::parallelStop()
)

# Best CV accuracy and the hyper-parameters that achieved it
gbm.tune$y
gbm.tune$x

# Set hyper-parameters ----
# Apply the best configuration found by tuning to the base learner.
gbm.ps <- setHyperPars(
  learner = gbm.learner
  , par.vals = gbm.tune$x
)

# Train gbm on the TRAINING task. Fitting on testTask and then predicting
# testTask would evaluate the model on the very data it was fit to.
gbm.train <- train(gbm.ps, trainTask)

# Learning curve for the tuned learner, computed on the training data
plotLearningCurve(
  generateLearningCurveData(
    gbm.ps
    , trainTask
  )
)

# Predict on the held-out test task
gbm.pred <- predict(gbm.train, testTask)
plotResiduals(gbm.pred)

# Create submission file from the prediction table
# (columns such as id, truth, prob.N, prob.Y, response)
gbm.submit <- data.frame(
  gbm.pred$data
)
head(gbm.submit, 5)
# Cross-tabulate observed vs. predicted class
table(gbm.submit$truth, gbm.submit$response)

# Confusion Matrix and ROC-based measures
calculateConfusionMatrix(gbm.pred)
calculateROCMeasures(gbm.pred)
# conf_mat_f1_func() and perf_plots_func() are user helpers not shown here
conf_mat_f1_func(gbm.pred)

perf_plots_func(Model = gbm.pred)

数据看起来像这样:

glimpse(train.df)
Observations: 33,031
Variables: 17
$ Init_Acct         <chr> "12345678", "87654321", "81734650", "11223344", "1422...
$ Init_LOS          <dbl> 2, 2, 5, 1, 12, 3, 16, 9, 3, 14, 1, 1, 4, 7, 4, 1, 3,...
$ Init_LACE         <dbl> 2, 7, 7, 9, 8, 8, 11, 10, 8, 10, 5, 4, 8, 8, 4, 5, 3,...
$ READMIT_FLAG      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, Y,...
$ Init_Hosp_Pvt     <fct> PRIVATE, HOSPITALIST, HOSPITALIST, HOSPITALIST, PRIVA...
$ Age_at_Init_Admit <dbl> 37, 26, 56, 67, 51, 53, 48, 57, 92, 67, 72, 22, 60, 6...
$ Age_Bucket        <fct> 3, 2, 5, 6, 5, 5, 4, 5, 9, 6, 7, 2, 6, 6, 7, 6, 9, 5,...
$ Gender            <fct> F, M, F, M, M, F, M, F, M, M, M, F, M, F, F, F, F, M,...
$ Init_ROM          <dbl> 1, 1, 3, 4, 2, 3, 1, 3, 4, 1, 1, 1, 1, 1, 2, 1, 2, 4,...
$ Init_SOI          <dbl> 1, 1, 3, 4, 3, 3, 3, 3, 3, 2, 1, 1, 2, 2, 2, 2, 2, 4,...
$ Has_Diabetes      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,...
$ reduced_dispo     <fct> AHR, AHR, AHR, ATH, ATW, ATW, ATW, AHR, AHR, ATW, AHR...
$ reduced_hsvc      <fct> SUR, MED, MED, Other, MED, MED, MED, MED, MED, MED, M...
$ reduced_abucket   <fct> 3, 2, 5, 6, 5, 5, 4, 5, Other, 6, 7, 2, 6, 6, 7, 6, O...
$ reduced_spclty    <fct> Other, HOSIM, HOSIM, HOSIM, Other, HOSIM, HOSIM, HOSI...
$ reduced_lihn      <fct> Other, Medical, Pneumonia, Medical, Medical, Medical,...
$ discharge_month   <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...

输出:

glimpse(gbm.submit)
Observations: 23,896
Variables: 5
$ id       <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,...
$ truth    <fct> Y, N, N, N, N, N, N, Y, N, N, N, N, N, N, Y, N, Y, N, N, N, N,...
$ prob.N   <dbl> 0.9150623, 0.7914781, 0.9661108, 0.9198683, 0.8502536, 0.94376...
$ prob.Y   <dbl> 0.08493774, 0.20852192, 0.03388919, 0.08013167, 0.14974644, 0....
$ response <fct> N, N, N, N, N, N, N, Y, N, N, N, N, N, N, N, N, Y, N, N, N, N,...

1 个答案:

答案 0 :(得分:1)

mlr 的 predict() 会保留行名,并在其输出中额外生成一个 id 列,该列是原始数据的行索引。您可以使用其中任意一种方式将预测结果与原始样本 ID 关联起来。

设置(以下准备代码由两个选项共用)

library(tidyverse)
library(mlr)

## Add a custom sample ID column ("Acct1", "Acct2", ...).
## seq_len(n()) is used instead of 1:n() so a zero-row data frame yields an
## empty ID vector rather than c(1, 0).
iris2 <- iris %>% mutate(Init_Acct = paste0("Acct", seq_len(n())))
## Learner that reports class probabilities
lrn <- makeLearner("classif.gbm", predict.type = "prob")

选项1:使用 id 列为原始数据编制索引(选项2 则是利用 predict() 输出中保留的行名进行关联)

## Drop the custom column as in your original post
task <- makeClassifTask(data = select(iris2, -Init_Acct), target = "Species")
mdl <- train(lrn, task)
pred <- predict(mdl, task)

## Join against the original data by the "id" column that predict() adds.
## The key is named explicitly so the join cannot silently match on any other
## column name the two tables happen to share.
iris2 %>%
  mutate(id = seq_len(n())) %>%
  select(Init_Acct, id) %>%
  inner_join(pred$data, by = "id") %>%
  select(-id)
#   Init_Acct  truth prob.setosa prob.versicolor prob.virginica response
# 1     Acct1 setosa   0.9998775    1.225043e-04   2.836942e-08   setosa
# 2     Acct2 setosa   0.9999652    3.468690e-05   1.118015e-07   setosa
# 3     Acct3 setosa   0.9999538    4.611200e-05   8.389636e-08   setosa