Plotting partial dependence for a binary target in R (mlr)

Time: 2018-03-28 14:58:57

Tags: r machine-learning visualization mlr

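I'm having trouble getting partial dependence plots from mlr to work properly. For some reason, the probabilities are not plotted, only the class labels. I suspect that the target gets lost while the partial dependence data is being created.

Any ideas?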

library(mlr)
library(dplyr)
library(ranger)

# select subset
iris_bin <- iris %>% 
  filter(Species != "virginica") %>% 
  mutate(bin_target = ifelse(Species == "setosa", TRUE, FALSE)) %>% 
  select(-Species)

# fit model
task_bin <- makeClassifTask(data = iris_bin, target = "bin_target")
lrn_bin  <- makeLearner("classif.ranger", predict.type = "prob")
fit_bin <- train(lrn_bin, task_bin)

# create partial dependence plot
pd <- generatePartialDependenceData(fit_bin, task_bin, "Sepal.Length")

pd  # is the target correct?
#> PartialDependenceData
#> Task: iris_bin
#> Features: Sepal.Length
#> Target: FALSE
#> Derivative: FALSE
#> Interaction: FALSE
#> Individual: FALSE
#>        FALSE Sepal.Length
#> 1: 0.4920347          4.3
#> 2: 0.4920347          4.6
#> 3: 0.4935947          4.9
#> 4: 0.4945947          5.2
#> 5: 0.5104600          5.5
#> 6: 0.5107800          5.8
#> ... (#rows: 10, #cols: 2)
plotPartialDependence(pd)


Here are the details of my current session, in case that helps:

Session info ---------------------------------------------
 setting  value                       
 version  R version 3.4.2 (2017-09-28)
 system   x86_64, mingw32             
 ui       RStudio (1.1.383)           
 language (EN)                        
 collate  German_Germany.1252         
 tz       Europe/Berlin               
 date     2018-03-29                  

Packages ----------------------------------------------------
 package      * version    date       source                                   
 assertthat     0.2.0      2017-04-11 CRAN (R 3.4.3)                           
 backports      1.1.2      2017-12-13 CRAN (R 3.4.3)                           
 base         * 3.4.2      2017-09-28 local                                    
 BBmisc         1.11       2017-03-10 CRAN (R 3.4.3)                           
 bindr          0.1.1      2018-03-13 CRAN (R 3.4.2)                           
 bindrcpp     * 0.2        2017-06-17 CRAN (R 3.4.3)                           
 checkmate      1.8.5      2017-10-24 CRAN (R 3.4.3)                           
 colorspace     1.3-2      2016-12-14 CRAN (R 3.4.3)                           
 compiler       3.4.2      2017-09-28 local                                    
 data.table     1.10.4-3   2017-10-27 CRAN (R 3.4.3)                           
 datasets     * 3.4.2      2017-09-28 local                                    
 devtools       1.13.5     2018-02-18 CRAN (R 3.4.3)                           
 digest         0.6.15     2018-01-28 CRAN (R 3.4.3)                           
 dplyr        * 0.7.4      2017-09-28 CRAN (R 3.4.3)                           
 ggplot2        2.2.1.9000 2018-03-26 Github (tidyverse/ggplot2@3c9c504)       
 glue           1.2.0      2017-10-29 CRAN (R 3.4.3)                           
 graphics     * 3.4.2      2017-09-28 local                                    
 grDevices    * 3.4.2      2017-09-28 local                                    
 grid           3.4.2      2017-09-28 local                                    
 gtable         0.2.0      2016-02-26 CRAN (R 3.4.3)                           
 labeling       0.3        2014-08-23 CRAN (R 3.4.1)                           
 lattice        0.20-35    2017-03-25 CRAN (R 3.4.2)                           
 lazyeval       0.2.1      2017-10-29 CRAN (R 3.4.3)                           
 magrittr       1.5        2014-11-22 CRAN (R 3.4.3)                           
 Matrix         1.2-11     2017-08-21 CRAN (R 3.4.2)                           
 memoise        1.1.0      2017-04-21 CRAN (R 3.4.3)                           
 methods      * 3.4.2      2017-09-28 local                                    
 mlr          * 2.13       2018-03-28 Github (mlr-org/mlr@a9036e3)             
 mmpf         * 0.0.4      2017-12-05 CRAN (R 3.4.4)                           
 munsell        0.4.3      2016-02-13 CRAN (R 3.4.3)                           
 parallel       3.4.2      2017-09-28 local                                    
 parallelMap    1.3        2015-06-10 CRAN (R 3.4.3)                           
 ParamHelpers * 1.11       2018-02-19 Github (berndbischl/ParamHelpers@59c649e)
 pillar         1.2.1      2018-02-27 CRAN (R 3.4.3)                           
 pkgconfig      2.0.1      2017-03-21 CRAN (R 3.4.3)                           
 plyr           1.8.4      2016-06-08 CRAN (R 3.4.3)                           
 R6             2.2.2      2017-06-17 CRAN (R 3.4.3)                           
 ranger       * 0.9.0      2018-01-09 CRAN (R 3.4.3)                           
 Rcpp           0.12.16    2018-03-13 CRAN (R 3.4.2)                           
 rlang          0.2.0.9001 2018-03-26 Github (r-lib/rlang@49d7a34)             
 rstudioapi     0.7        2017-09-07 CRAN (R 3.4.3)                           
 scales         0.5.0      2017-08-24 CRAN (R 3.4.4)                           
 splines        3.4.2      2017-09-28 local                                    
 stats        * 3.4.2      2017-09-28 local                                    
 stringi        1.1.7      2018-03-12 CRAN (R 3.4.2)                           
 survival       2.41-3     2017-04-04 CRAN (R 3.4.2)                           
 tibble         1.4.2      2018-01-22 CRAN (R 3.4.3)                           
 tools          3.4.2      2017-09-28 local                                    
 utils        * 3.4.2      2017-09-28 local                                    
 withr          2.1.2      2018-03-26 Github (jimhester/withr@79d7b0d)         
 XML            3.98-1.10  2018-02-19 CRAN (R 3.4.3)                           
 yaml           2.1.18     2018-03-08 CRAN (R 3.4.3)

1 answer:

Answer 0 (score: 2)

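Hopefully the mlr package maintainers can help (I don't use that package). In the meantime, though, you can fit the model directly and just use the pdp package: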

fit <- ranger(as.factor(bin_target) ~ ., data = iris_bin, 
              probability = TRUE)
library(ggplot2)
library(pdp)
pd <- partial(fit, pred.var = "Sepal.Length", prob = TRUE)
autoplot(pd)

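Note the use of prob = TRUE in the call to partial; without it, partial returns the partial dependence for classification models on the centered logit scale rather than the probability scale. Also, ggplot2 is not required, since you can use plotPartial(pd) instead to get a lattice graphic.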
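One thing to watch for: as far as I can tell from the pdp documentation, partial focuses on the first class level by default (which.class = 1). With the levels ordered FALSE, TRUE, the curve above would therefore be the probability of FALSE, i.e. "not setosa". Here is a minimal sketch for plotting the probability of TRUE instead. It builds the factor explicitly so the level order is known, and passes train = iris_bin only to be explicit; the name pd_true is mine, not from the question.

```r
library(dplyr)
library(ranger)
library(pdp)

# Same subset as in the question, but with an explicit factor target
# so the level order (FALSE first, TRUE second) is unambiguous.
iris_bin <- iris %>%
  filter(Species != "virginica") %>%
  mutate(bin_target = factor(Species == "setosa", levels = c(FALSE, TRUE))) %>%
  select(-Species)

# Probability forest, so predictions are class probabilities
fit <- ranger(bin_target ~ ., data = iris_bin, probability = TRUE)

# which.class = 2L selects the second level ("TRUE", i.e. setosa)
pd_true <- partial(fit, pred.var = "Sepal.Length", prob = TRUE,
                   which.class = 2L, train = iris_bin)

# lattice-based plot, no ggplot2 needed
plotPartial(pd_true)
```

The returned object is a data frame with the grid of Sepal.Length values and a yhat column, so you can also feed it to your own plotting code.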

Additionally, you can still fit the model with mlr and then use partial; for example,

library(mlr)
library(dplyr)
library(ranger)
library(pdp)

# select subset
iris_bin <- iris %>% 
  filter(Species != "virginica") %>% 
  mutate(bin_target = ifelse(Species == "setosa", TRUE, FALSE)) %>% 
  select(-Species)

# fit model
task_bin <- makeClassifTask(data = iris_bin, target = "bin_target")
lrn_bin  <- makeLearner("classif.ranger", predict.type = "prob")
fit_bin <- train(lrn_bin, task_bin)

# partial dependence plot
mod <- getLearnerModel(fit_bin)  # EXTRACT THE MODEL!!  <<--
partial(mod, pred.var = "Sepal.Length", prob = TRUE, 
        plot = TRUE, train = iris_bin)

Note, however, that the original training data has to be supplied via the train argument.
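To see why the training data is needed, it may help to compute the partial dependence by hand: for each grid value, every training row gets its Sepal.Length overwritten with that value, and the predicted probabilities are averaged. The following is a rough sketch of that idea using ranger directly, not what pdp or mlr do internally. The names grid, pd_manual and prob_setosa are made up for illustration, and I'm assuming the prediction matrix columns are named after the factor levels.

```r
library(ranger)

# Same binary subset as above, built with base R
iris_bin <- subset(iris, Species != "virginica")
iris_bin$bin_target <- factor(iris_bin$Species == "setosa",
                              levels = c(FALSE, TRUE))
iris_bin$Species <- NULL

set.seed(1)
fit <- ranger(bin_target ~ ., data = iris_bin, probability = TRUE)

# Evenly spaced grid over the observed range of the feature
grid <- seq(min(iris_bin$Sepal.Length), max(iris_bin$Sepal.Length),
            length.out = 10)

pd_manual <- data.frame(
  Sepal.Length = grid,
  prob_setosa = vapply(grid, function(v) {
    newdata <- iris_bin
    newdata$Sepal.Length <- v   # hold the feature fixed for all rows
    # average predicted probability of the "TRUE" class over the data
    mean(predict(fit, data = newdata)$predictions[, "TRUE"])
  }, numeric(1))
)

plot(pd_manual, type = "l")
```

Since the average runs over the training rows, a model object extracted with getLearnerModel is not enough by itself if partial cannot recover the data from the model's call.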