为什么R包ForecastML每次运行都会产生不同的结果?

时间:2020-11-10 08:23:38

标签: r forecast

我正在尝试使用this overview的ForecastML包来预测多个时间序列。对我来说奇怪的是,每次运行代码时,对于data_combined值我都会得到不同的结果。我运行的完整脚本是:

library(forecastML)
library(dplyr)
library(glmnet)
library(randomForest)

data <- cbind(
c(107,97,102,87,119,106,110,106,107,134,147,180,125,134,110,102,103,111,120,129,122,183,169,190,134,108,104,117,157,148,130,140,136,140,187,150,159,143,114,127,159,156,138,120,117,170,168,198,144,146,109,131,151,140,153,140,161,168,152,136,113,100,103,103,121,134,133,129,144,154,156,163,122,92,117,95,96,108,108,106,140,114,158,161,102,127,125,101,97,112,112,113,108,128,154,162,112,79,82,127,108,110,123,103,97,140,165,183,148,111,116,115,100,106,134,125,117,122,153,178,114,94,128,119,111,110,114,118,115,132,153,171,115,95,92,100,95,114,102,104,132,136,117,137,111,106,98,84,94,105,123,109,130,153,134,99,115,104,131,108,103,115,122,122,125,137,138,152,120,95,100,89,82,89,60,84,113,126,122,118,92,86,81,84,87,90,79,96,122,120,137,154),
c(9059,7685,9963,10955,11823,12391,13460,14055,12106,11372,9834,9267,9130,8933,11000,10733,12912,12926,13990,14926,12900,12034,10643,10742,10266,10281,11527,12281,13587,13049,16055,15220,13824,12729,11467,11351,10803,10548,12368,13311,13885,14088,16932,16164,14883,13532,12220,12025,11692,11081,13745,14382,14391,15597,16834,17282,15779,13946,12701,10431,11616,10808,12421,13605,14455,15019,15662,16745,14717,13756,12531,12568,11249,11096,12637,13018,15005,15235,15552,16905,14776,14104,12854,12956,12177,11918,13517,14417,15911,15589,16543,17925,15406,14601,13107,12268,11972,12028,14033,14244,15287,16954,17361,17694,16222,14969,13624,13842,12387,11608,15021,14834,16565,16882,18012,18855,17243,16045,14745,13726,11196,12105,14723,15582,16863,16758,17434,18359,17189,16909,15380,15161,14027,14478,16155,16585,18117,17552,18299,19361,17924,17872,16058,15746,15226,14932,16846,16854,18146,17559,18655,19453,17923,17915,16496,13544,13601,15667,17358,18112,18581,18759,20668,21040,18993,18668,16768,16551,16231,15511,18308,17793,19205,19162,20997,20705,18759,19240,17504,16591,16224,16670,18539,19759,19584,19976,21486,21626,20195,19928,18564,18149),
c(0.102971811805368,0.102362995884646,0.102062490635914,0.100873300511862,0.101019672891934,0.10058119170287,0.103773981457839,0.104076403554621,0.103773981457839,0.103026401330572,0.102730112155946,0.101997191539847,0.101274563494893,0.10070397563972,0.100139606658898,0.0986211043713023,0.0983492854059603,0.0980801772105387,0.0972792082183714,0.0974106238350488,0.0974252365245483,0.0963806330037465,0.0957389559626943,0.0951063062359475,0.0967359671470176,0.0961092224873678,0.095367254851379,0.0947095915871269,0.0941176202174608,0.0935321548190638,0.0929540494377308,0.0928397862431927,0.0927247362539862,0.0922696509793897,0.0917066851479679,0.0912620719433678,0.090711603254936,0.090276328119195,0.0899519176272147,0.0890996386561615,0.0886791925043499,0.0881592888670634,0.0889020568552906,0.0881813314444876,0.0889402929599117,0.0877266104275971,0.087428846437772,0.0870354301608856,0.0864499193294655,0.0858726409121568,0.0853982218357345,0.083821981233605,0.0845907801489325,0.0841369037739444,0.0837784051341314,0.0835107427259604,0.0828063938633846,0.0811788933269884,0.0828536069623417,0.0941901186933595,0.0923998429510411,0.108161478199019,0.10721168869023,0.114042966782082,0.112454115810183,0.111316253290611,0.11030125221242,0.108197177376865,0.107027443082328,0.104946980916917,0.119357749193208,0.117621904277373,0.133027420877451,0.130845243689729,0.128318477474772,0.123547448292297,0.118586811514179,0.116337480161004,0.11516147558196,0.114501197216867,0.113522979499817,0.111930179432996,0.110610528503361,0.11527438914664,0.113793485966034,0.112349582098189,0.111753469387189,0.109642522576533,0.10844089510559,0.107884938936114,0.109084769191454,0.107571450111271,0.106164022368002,0.106299999323319,0.104825313000088,0.103451745711815,0.101449920129493,0.100402316427863,0.098862033680192,0.102496154313521,0.103027431599736,0.102178908220655,0.0998366428726473,0.0926366895833353,0.0918149629077569,0.090724303768407,0.0900212072768793,0.0893307058230937,0.0884427348717763,0.0883525692744791,0.0867573619308237,0.0849952420449752,0.0845679437213488,0.0844318988774436,0.0843508831482932,0.0836009830491076,0.0834172630524962,0.0827451397987249,0.0852352669035281,0.0847703028296526,0.0844589214084587,0.085352119244763,0.0875592125175749,0.0903829170614837,0.0907832937355188,0.108742780219868,0.114142227335262,0.112992933231466,0.111320706029796,0.109126229280665,0.107698459343112,0.107601574334496,0.103775019202843,0.107114170431059,0.107374774370757,0.111695372689559,0.110638184592354,0.111855211329895,0.109742342683337,0.108193931510232,0.106255362697951,0.104193034427699,0.101933972880902,0.102793824574291,0.10476034144929,0.104002535534347,0.116655515402424,0.11516147558196,0.112989543494316,0.113860643932406,0.119118081064489,0.124489986005886,0.123222945411622,0.12067793212866,0.121048982651421,0.116968571491487,0.112750259392875,0.108079306704711,0.108838515984019,0.111291766408542,0.111304009176187,0.115454357532553,0.114768296055692,0.117207430931122,0.119076397031248,0.117965862171995,0.117449127100423,0.116988457838933,0.11261053571781,0.113657015681422,0.113144445252379,0.118495534815352,0.117969401200945,0.1176866141183,0.120059238961094,0.119437745680998,0.118881271786551,0.118462360710195,0.118016598400236,0.117706622543368,0.117776089941536,0.114796991716514,0.11573525277085,0.115356263024722,0.114815360704668,0.114777477886645,0.114935980147534,0.114796991716514,0.114093156728444,0.116465521799171,0.116026113132354,0.116066729379379),
c(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1)
)
colnames(data) <- c("DriversKilled", "kms", "PetrolPrice", "law")
data <- as.data.frame(data)

date_frequency <- "1 month"
data$PetrolPrice <- round(data$PetrolPrice, 3)
data_train <- data[1:(nrow(data) - 12), ]
data_test <- data[(nrow(data) - 12 + 1):nrow(data), ]

outcome_col <- 1
horizons <- c(1, 3, 6, 12)
lookback <- c(1:6, 9, 12, 15)
dynamic_features <- "law"

data_list <- forecastML::create_lagged_df(data_train, outcome_col = outcome_col, type = "train", horizons = horizons, lookback = lookback, dynamic_features = dynamic_features)
windows <- forecastML::create_windows(data_list, window_length = 0)

data_forecast_list <- forecastML::create_lagged_df(data_train, outcome_col = outcome_col, type = "forecast", horizons = horizons, lookback = lookback, dynamic_features = dynamic_features)
                                                   
for (i in seq_along(data_forecast_list)) {
  data_forecast_list[[i]]$law <- 1
}
                                          
model_function <- function(data) {
  constant_features <- which(unlist(lapply(data[, -1], function(x) {!(length(unique(x)) > 1)})))
  if (length(constant_features) > 1) {
    data <- data[, -c(constant_features + 1)]
  }
  x <- data[, -(1), drop = FALSE]
  y <- data[, 1, drop = FALSE]
  x <- as.matrix(x, ncol = ncol(x))
  y <- as.matrix(y, ncol = ncol(y))
  model <- glmnet::cv.glmnet(x, y, nfolds = 3)
  return(list("model" = model, "constant_features" = constant_features))
}
model_function_2 <- function(data) {
  outcome_names <- names(data)[1]
  model_formula <- formula(paste0(outcome_names,  "~ ."))
  model <- randomForest::randomForest(formula = model_formula, data = data, ntree = 200)
  return(model)
}
                                           
prediction_function <- function(model, data_features) {
  if (length(model$constant_features) > 1) {
    data_features <- data_features[, -c(model$constant_features )]
  }
  x <- as.matrix(data_features, ncol = ncol(data_features))
  data_pred <- data.frame("y_pred" = predict(model$model, x, s = "lambda.min"))
  return(data_pred)
}
prediction_function_2 <- function(model, data_features) {
  data_pred <- data.frame("y_pred" = predict(model, data_features))
  return(data_pred)
}

model_results <- forecastML::train_model(data_list, windows, model_name = "LASSO", model_function, use_future = FALSE)
model_results_2 <- forecastML::train_model(data_list, windows, model_name = "RF", model_function_2, use_future = FALSE)

data_forecast <- predict(model_results, model_results_2, prediction_function = list(prediction_function, prediction_function_2), data = data_forecast_list)
data_forecast$DriversKilled_pred <- round(data_forecast$DriversKilled_pred, 0)
data_error <- forecastML::return_error(data_forecast, data_test = data_test, test_indices = c(181:192))
data_combined <- forecastML::combine_forecasts(data_forecast)

我逐行进行了测试,发现差异从model_results和model_results_2开始。谁能解释这个区别?

0 个答案:

没有答案