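SVM has not been trained using `probability = TRUE`, probabilities not available for predictions

Time: 2020-01-08 16:57:09

Tags: mlr3

I ran into a problem when trying to output prediction probabilities from an SVM with mlr3.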


library(mlr3)
task = mlr_tasks$get("iris")
svm_learner = mlr_learners$get("classif.svm")
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
svm_learner$train(task, row_ids = task$row_ids[train_set])
svm_learner$predict_type<-"prob"
prediction<-svm_learner$predict(task,row_ids = task$row_ids[test_set])
prediction

Warning message:
In predict.svm(self$model, newdata = newdata, probability = (self$predict_type == :
  SVM has not been trained using `probability = TRUE`, probabilities not available for predictions.

I know that an SVM does not output probabilities, but it can evaluate the separating hyperplane function on the prediction data and return a signed distance from the hyperplane. I want to retrieve those signed distances and use them to compute the AUC. But with predict_type<-"response" I only get the predicted class, not the signed distances, and with predict_type<-"prob" I get the warning above.

Session info

> sessionInfo(package = NULL)
R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 17763)

Matrix products: default

Random number generation:
 RNG:     Mersenne-Twister
 Normal:  Inversion
 Sample:  Rounding

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C
[5] LC_TIME=English_United States.1252

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base

other attached packages:
 [1] precrec_0.10.1     forcats_0.4.0      stringr_1.4.0      purrr_0.3.3        readr_1.3.1
 [6] tidyr_1.0.0        tibble_2.1.3       tidyverse_1.2.1    dplyr_0.8.3        mlr3learners_0.1.5
[11] GGally_1.4.0       ggplot2_3.2.1      mlr3_0.1.6         mlr3viz_0.1.0      e1071_1.7-3
[16] biomaRt_2.38.0

loaded via a namespace (and not attached):
 [1] Biobase_2.42.0       httr_1.4.1           bit64_0.9-7          jsonlite_1.6
 [5] modelr_0.1.4         assertthat_0.2.1     lgr_0.3.3            stats4_3.6.2
 [9] blob_1.2.0           cellranger_1.1.0     mlr3misc_0.1.6       progress_1.2.2
[13] pillar_1.4.3         RSQLite_2.1.2        backports_1.1.5      lattice_0.20-38
[17] glue_1.3.1           uuid_0.1-2           digest_0.6.23        RColorBrewer_1.1-2
[21] checkmate_1.9.4      rvest_0.3.3          colorspace_1.4-1     plyr_1.8.5
[25] XML_3.98-1.20        pkgconfig_2.0.3      mlr3measures_0.1.1   broom_0.5.2
[29] haven_2.1.0          scales_1.0.0         generics_0.0.2       IRanges_2.16.0
[33] withr_2.1.2          BiocGenerics_0.28.0  lazyeval_0.2.2       cli_2.0.0
[37] magrittr_1.5         crayon_1.3.4         readxl_1.3.1         paradox_0.1.0
[41] memoise_1.1.0        fansi_0.4.0          nlme_3.1-142         xml2_1.2.0
[45] class_7.3-15         tools_3.6.2          data.table_1.12.8    prettyunits_1.0.2
[49] hms_0.5.2            lifecycle_0.1.0      S4Vectors_0.20.1     munsell_0.5.0
[53] AnnotationDbi_1.44.0 compiler_3.6.2       rlang_0.4.1          grid_3.6.2
[57] RCurl_1.95-4.12      rstudioapi_0.10      bitops_1.0-6         labeling_0.3
[61] gtable_0.3.0         DBI_1.0.0            reshape_0.8.8        reshape2_1.4.3
[65] R6_2.4.1             lubridate_1.7.4      bit_1.1-14           zeallot_0.1.0
[69] stringi_1.4.3        parallel_3.6.2       Rcpp_1.0.2           vctrs_0.2.1
[73] tidyselect_0.2.5

1 answer:

Answer 0: (score: 1)

Your code has the steps in the wrong order. Modify it as follows:

library(mlr3)
library(mlr3learners)  # provides the classif.svm learner (backed by e1071)

task = mlr_tasks$get("iris")
svm_learner = mlr_learners$get("classif.svm")
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

# Set the predict type first, so that training happens with "prob" already in place
svm_learner$predict_type = "prob"
svm_learner$train(task, row_ids = task$row_ids[train_set])
prediction = svm_learner$predict(task, row_ids = task$row_ids[test_set])
prediction

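Note that predict_type is changed first, and training happens afterwards. Judging by the warning, the underlying e1071 model has to be fitted with probability = TRUE, and the learner appears to request that only when predict_type is already "prob" at training time, so changing it after train() has no effect on the fitted model.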
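
For the signed distances the question asks about, one option is to call e1071 directly, since the warning comes from its predict.svm. The following is only a minimal sketch, not something taken from the answer above: it assumes the e1071 package is installed, uses a two-class subset of iris so that there is a single decision value per row, and defines a hypothetical helper auc_from_scores() that computes a rank-based (Mann-Whitney) AUC in base R. It also assumes that positive decision values favour the first class named in the decision-value column; if the resulting AUC comes out below 0.5, that assumption should be checked.

```r
library(e1071)

set.seed(1)

# Two-class subset of iris: with two classes there is exactly one
# decision value (signed distance) per row.
dat <- droplevels(iris[iris$Species != "setosa", ])
train_idx <- sample(nrow(dat), 0.8 * nrow(dat))
train <- dat[train_idx, ]
test <- dat[-train_idx, ]

# probability = TRUE has to be given when the model is fitted.
fit <- svm(Species ~ ., data = train, probability = TRUE)

# Ask predict() for decision values and probabilities; both come back
# as attributes of the returned factor of predicted classes.
pred <- predict(fit, newdata = test, decision.values = TRUE, probability = TRUE)
dec <- attr(pred, "decision.values")   # one column, named like "classA/classB"
prob <- attr(pred, "probabilities")    # one column per class

# Hypothetical helper: rank-based AUC (Mann-Whitney statistic).
auc_from_scores <- function(scores, is_positive) {
  r <- rank(scores)
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Assumption: positive decision values favour the first class in the column name.
pos_class <- strsplit(colnames(dec)[1], "/", fixed = TRUE)[[1]][1]
auc <- auc_from_scores(dec[, 1], test$Species == pos_class)
auc
```

The decision values can be ranked directly, so this AUC does not depend on the probability calibration step at all; probability = TRUE is only needed here to show where the "probabilities" attribute comes from.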