预处理数据后,如何在 R 的 tidymodels 中进行预测

时间:2020-08-04 02:57:21

标签: r linear-regression prediction preprocessor tidymodels

您好,我正在尝试使用 tidymodels 建立一个线性回归模型的示例。我已经用该框架正确地拟合了模型,并在工作流(workflow)中使用 collect_metrics() 和 collect_predictions() 对其进行了测试。但是,当我尝试用该模型对新数据进行预测时,却无法使其正常工作。我正在尝试改编下面这个示例:


# Fit the workflow `rf_wflow_final` on `dia_train`.
rf_wflow_final_fit <- fit(rf_wflow_final, data = dia_train)

# Pull the prepped recipe and the fitted model out of the fitted workflow.
dia_rec3     <- pull_workflow_prepped_recipe(rf_wflow_final_fit)
rf_final_fit <- pull_workflow_fit(rf_wflow_final_fit)

# Predict with the extracted model (not the workflow itself), so the test data
# are baked with the extracted recipe before being passed as `new_data`.
dia_test$.pred <- predict(rf_final_fit, 
                          new_data = bake(dia_rec3, dia_test))$.pred
# Log of the raw price, used as the truth column in the metrics call below.
dia_test$logprice <- log(dia_test$price)

# Compare the predictions against log(price) on the test data.
metrics(dia_test, truth = logprice, estimate = .pred)
#> # A tibble: 3 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard      0.113 
#> 2 rsq     standard      0.988 
#> 3 mae     standard      0.0846

这就是我在做什么:

# Load the diamonds data and create a seeded train/test split
# (prop = 4/5 of the rows go to the training set).
data("diamonds")
set.seed(234589)
diamonds_split <- initial_split(diamonds, prop = 4/5)

diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)

# Recipe with `price` as the outcome and all other columns as predictors:
# log the outcome, normalize the non-nominal predictors, dummy-encode the
# nominal columns, then add a degree-2 polynomial expansion of `carat`.
diamonds_recipe <- 
  recipe(price ~ ., data = diamonds_train) %>%
  step_log(all_outcomes()) %>%
  step_normalize(all_predictors(), -all_nominal()) %>%
  step_dummy(all_nominal()) %>%
  step_poly(carat, degree = 2)

# Stand-alone prepped copy of the recipe; it is used in the bake() calls of
# the prediction attempts further down.
preprocesados <- prep(diamonds_recipe)

# Linear regression specification using the "lm" engine.
lr_model <- 
  linear_reg()%>%
  set_engine("lm") %>%
  set_mode("regression")

# Bundle the (unprepped) recipe and the model specification into a workflow.
lr_workflow <- workflow() %>%
  add_recipe(diamonds_recipe) %>%
  add_model(lr_model)

# Fit and evaluate the workflow using the train/test split object.
lr_fitted_workflow <-  lr_workflow %>%
  last_fit(diamonds_split)

# Collect the metrics and the predictions produced by last_fit().
performance <- lr_fitted_workflow %>% collect_metrics()
test_predictions <- lr_fitted_workflow %>% collect_predictions()

# Fit the workflow on the full `diamonds` data; this is the object passed to
# predict() in the attempts below.
final_model <- fit(lr_workflow, diamonds)

到目前为止,一切似乎都正常,但当我尝试使用 predict() 函数时出现了错误。

我已经尝试过了:

# Attempt 1: bake the test set with the prepped recipe and pass the baked data
# to the fitted workflow. This raises the "required columns are missing" error
# shown below.
predict(final_model, new_data = bake(preprocesados, diamonds_test))

Error: The following required columns are missing: 'carat', 'cut', 'color', 'clarity'.
Traceback:

1. predict(final_model, new_data = bake(preprocesados, diamonds_test))
2. predict.workflow(final_model, new_data = bake(preprocesados, 
 .     diamonds_test))
3. hardhat::forge(new_data, blueprint)
4. forge.data.frame(new_data, blueprint)
5. blueprint$forge$clean(blueprint = blueprint, new_data = new_data, 
 .     outcomes = outcomes)
6. shrink(new_data, blueprint$ptypes$predictors)
7. validate_column_names(data, cols)
8. glubort("The following required columns are missing: {missing_names}.")
9. abort(glue(..., .sep = .sep, .envir = .envir))
10. signal_abort(cnd)

我还尝试了这样:

# Attempt 2: one hand-written observation containing predictors only (there is
# no `price` column); cut, color and clarity are given as character strings.
new_diamond <- tribble(~carat, ~cut, ~color, ~clarity, ~depth, ~table, ~x, ~y, ~z,
                        0.23,   "Ideal",    "E",    "SI2",  61.5,   55, 3.95, 3.98, 2.43)

# Baking this row with the prepped recipe fails inside the step_log step
# (see the warning and traceback below).
predict(final_model, new_data = bake(preprocesados, new_diamond))

Warning message:
“ There were 3 columns that were factors when the recipe was prepped:
 'cut', 'color', 'clarity'.
 This may cause errors when processing new data.”

Error: Assigned data `log(new_data[[col_names[i]]] + object$offset, base = object$base)` must be compatible with existing data.
✖ Existing data has 1 row.
✖ Assigned data has 0 rows.
ℹ Row updates require a list value. Do you need `list()` or `as.list()`?
Traceback:

1. predict(final_model, new_data = bake(preprocesados, new_diamond))
2. predict.workflow(final_model, new_data = bake(preprocesados, 
 .     new_diamond))
3. hardhat::forge(new_data, blueprint)
4. bake(preprocesados, new_diamond)
5. bake.recipe(preprocesados, new_diamond)
6. bake(object$steps[[i]], new_data = new_data)
7. bake.step_log(object$steps[[i]], new_data = new_data)
8. `[<-`(`*tmp*`, , col_names[i], value = numeric(0))
9. `[<-.tbl_df`(`*tmp*`, , col_names[i], value = numeric(0))
10. tbl_subassign(x, i, j, value, i_arg, j_arg, substitute(value))
...

任何帮助都会非常感激

1 个答案:

答案 0 :(得分:2)

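尽量不要把 bake() 与工作流(workflow)混合使用:对已拟合的 workflow 调用 predict() 时,它会自动对新数据应用 recipe 中的预处理,因此应直接传入未经 bake() 处理的原始数据。另外请记住,对 all_outcomes() 使用的步骤(例如 step_log)通常需要设置 skip = TRUE,这样在对新数据进行预测时就会跳过该步骤。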

library(tidymodels)
#> -- Attaching packages --------------------------------------------------------------------------------------------- tidymodels 0.1.1 --
#> v broom     0.7.0      v recipes   0.1.13
#> v dials     0.0.8      v rsample   0.0.7 
#> v dplyr     1.0.0      v tibble    3.0.3 
#> v ggplot2   3.3.2      v tidyr     1.1.0 
#> v infer     0.5.3      v tune      0.1.1 
#> v modeldata 0.0.2      v workflows 0.1.2 
#> v parsnip   0.1.2      v yardstick 0.0.7 
#> v purrr     0.3.4
#> -- Conflicts ------------------------------------------------------------------------------------------------ tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
# Load the diamonds data and create a seeded train/test split
# (prop = 4/5 of the rows go to the training set).
data("diamonds")
set.seed(234589)
diamonds_split <- initial_split(diamonds, prop = 4/5)

diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)

# Recipe with `price` as the outcome and all other columns as predictors.
# The log transform targets the outcome, so it is marked `skip = TRUE`: the
# step is applied when the recipe is prepped on the training data but skipped
# when new data are baked (as happens inside predict() on a workflow), where
# the outcome column is not available.
# The remaining steps normalize the non-nominal predictors, dummy-encode the
# nominal columns and add a degree-2 polynomial expansion of `carat`.
diamonds_recipe <-
  recipe(price ~ ., data = diamonds_train) %>%
  step_log(all_outcomes(), skip = TRUE) %>%
  step_normalize(all_predictors(), -all_nominal()) %>%
  step_dummy(all_nominal()) %>%
  step_poly(carat, degree = 2)

# Stand-alone prepped copy of the recipe, carried over from the question; it
# is not used by the workflow-based prediction below.
preprocesados <- diamonds_recipe %>% prep()

# Linear regression specification using the "lm" engine.
lr_model <- linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

# Bundle the (unprepped) recipe and the model specification into a workflow.
lr_workflow <-
  workflow() %>%
  add_recipe(diamonds_recipe) %>%
  add_model(lr_model)

# Fit the workflow on the full `diamonds` data.
# NOTE(review): `diamonds` also contains the rows of `diamonds_test`, so the
# predictions below are made on data the model was fitted on.
final_model <- lr_workflow %>% fit(data = diamonds)

# Pass the raw test data straight to the fitted workflow: no bake() call.
# Because the outcome was logged by the recipe, `.pred` is on the log(price)
# scale.
final_model %>% predict(new_data = diamonds_test)
#> # A tibble: 10,787 x 1
#>    .pred
#>    <dbl>
#>  1  5.94
#>  2  5.91
#>  3  5.87
#>  4  6.23
#>  5  6.22
#>  6  6.29
#>  7  6.05
#>  8  6.08
#>  9  6.35
#> 10  6.04
#> # ... with 10,777 more rows

由 reprex 包(v0.3.0)创建于 2020-08-04