函数作用于一个数据集,而不作用于另一数据集。 LSTM模型

时间:2019-01-08 13:54:58

标签: r

我正在使用以下代码对R中的太阳斑数据运行一个函数:

# Core Tidyverse
library(tidyverse)
library(glue)
library(forcats)

# Time Series
library(timetk)
library(tidyquant)
library(tibbletime)

# Visualization
library(cowplot)

# Preprocessing
library(recipes)

# Sampling / Accuracy
library(rsample)
library(yardstick) 

# Modeling
library(keras)

sun_spots <- datasets::sunspot.month %>%
  tk_tbl() %>%
  mutate(index = as_date(index)) %>%
  as_tbl_time(index = index)

sun_spots
############################################

periods_train <- 12 * 50
periods_test  <- 12 * 10
skip_span     <- 12 * 20

rolling_origin_resamples <- rolling_origin(
  sun_spots,
  initial    = periods_train,
  assess     = periods_test,
  cumulative = FALSE,
  skip       = skip_span
)

rolling_origin_resamples


predict_keras_lstm <- function(split, epochs = 300, ...) {

  lstm_prediction <- function(split, epochs, ...) {

    # 5.1.2 Data Setup
    df_trn <- training(split)
    df_tst <- testing(split)

    df <- bind_rows(
      df_trn %>% add_column(key = "training"),
      df_tst %>% add_column(key = "testing")
    ) %>% 
      as_tbl_time(index = index)

    # 5.1.3 Preprocessing
    rec_obj <- recipe(value ~ ., df) %>%
      step_sqrt(value) %>%
      step_center(value) %>%
      step_scale(value) %>%
      prep()

    df_processed_tbl <- bake(rec_obj, df)

    center_history <- rec_obj$steps[[2]]$means["value"]
    scale_history  <- rec_obj$steps[[3]]$sds["value"]

    # 5.1.4 LSTM Plan
    lag_setting  <- 120 # = nrow(df_tst)
    batch_size   <- 40
    train_length <- 440
    tsteps       <- 1
    epochs       <- epochs

    # 5.1.5 Train/Test Setup
    lag_train_tbl <- df_processed_tbl %>%
      mutate(value_lag = lag(value, n = lag_setting)) %>%
      filter(!is.na(value_lag)) %>%
      filter(key == "training") %>%
      tail(train_length)

    x_train_vec <- lag_train_tbl$value_lag
    x_train_arr <- array(data = x_train_vec, dim = c(length(x_train_vec), 1, 1))

    y_train_vec <- lag_train_tbl$value
    y_train_arr <- array(data = y_train_vec, dim = c(length(y_train_vec), 1))

    lag_test_tbl <- df_processed_tbl %>%
      mutate(
        value_lag = lag(value, n = lag_setting)
      ) %>%
      filter(!is.na(value_lag)) %>%
      filter(key == "testing")

    x_test_vec <- lag_test_tbl$value_lag
    x_test_arr <- array(data = x_test_vec, dim = c(length(x_test_vec), 1, 1))

    y_test_vec <- lag_test_tbl$value
    y_test_arr <- array(data = y_test_vec, dim = c(length(y_test_vec), 1))

    # 5.1.6 LSTM Model
    model <- keras_model_sequential()

    model %>%
      layer_lstm(units            = 50, 
                 input_shape      = c(tsteps, 1), 
                 batch_size       = batch_size,
                 return_sequences = TRUE, 
                 stateful         = TRUE) %>% 
      layer_lstm(units            = 50, 
                 return_sequences = FALSE, 
                 stateful         = TRUE) %>% 
      layer_dense(units = 1)

    model %>% 
      compile(loss = 'mae', optimizer = 'adam')

    # 5.1.7 Fitting LSTM
    for (i in 1:epochs) {
      model %>% fit(x          = x_train_arr, 
                    y          = y_train_arr, 
                    batch_size = batch_size,
                    epochs     = 1, 
                    verbose    = 1, 
                    shuffle    = FALSE)

      model %>% reset_states()
      cat("Epoch: ", i)

    }

    # 5.1.8 Predict and Return Tidy Data
    # Make Predictions
    pred_out <- model %>% 
      predict(x_test_arr, batch_size = batch_size) %>%
      .[,1] 

    # Retransform values
    pred_tbl <- tibble(
      index   = lag_test_tbl$index,
      value   = (pred_out * scale_history + center_history)^2
    ) 

    # Combine actual data with predictions
    tbl_1 <- df_trn %>%
      add_column(key = "actual")

    tbl_2 <- df_tst %>%
      add_column(key = "actual")

    tbl_3 <- pred_tbl %>%
      add_column(key = "predict")

    # Create time_bind_rows() to solve dplyr issue
    time_bind_rows <- function(data_1, data_2, index) {
      index_expr <- enquo(index)
      bind_rows(data_1, data_2) %>%
        as_tbl_time(index = !! index_expr)
    }

    ret <- list(tbl_1, tbl_2, tbl_3) %>%
      reduce(time_bind_rows, index = index) %>%
      arrange(key, index) %>%
      mutate(key = as_factor(key))

    return(ret)

  }

  safe_lstm <- possibly(lstm_prediction, otherwise = NA)

  safe_lstm(split, epochs, ...)

}

#################################################

sample_predictions_lstm_tbl <- rolling_origin_resamples %>%
  mutate(predict = map(splits, predict_keras_lstm, epochs = 3))

sample_predictions_lstm_tbl


sample_predictions_lstm_tbl$predict

哪个给我以下输出(对于Split 11):

[[11]]
# A time tibble: 840 x 3
# Index: index
   index      value key   
   <date>     <dbl> <fct> 
 1 1949-11-01 144.  actual
 2 1949-12-01 118.  actual
 3 1950-01-01 102.  actual
 4 1950-02-01  94.8 actual
 5 1950-03-01 110.  actual
 6 1950-04-01 113.  actual
 7 1950-05-01 106.  actual
 8 1950-06-01  83.6 actual
 9 1950-07-01  91   actual
10 1950-08-01  85.2 actual
# ... with 830 more rows

但是,当我对数据运行以下脚本时,可以获得NA结果,但数据结构与sun_spots数据相同。

sun_spots数据结构:

> str(sun_spots)
Classes ‘tbl_time’, ‘tbl_df’, ‘tbl’ and 'data.frame':   3177 obs. of  2 variables:
 $ index: Date, format: "1749-01-01" "1749-02-01" "1749-03-01" "1749-04-01" ...
 $ value: num  58 62.6 70 55.7 85 83.5 94.8 66.3 75.9 75.5 ...
 - attr(*, "index_quo")= language ~index
  ..- attr(*, ".Environment")=<environment: 0x000000001a339268> 
 - attr(*, "index_time_zone")= chr "UTC"

我的数据结构:

> str(store)
Classes ‘tbl_time’, ‘tbl_df’, ‘tbl’ and 'data.frame':   252 obs. of  2 variables:
 $ index: Date, format: "2007-12-31" "2008-01-07" "2008-01-14" "2008-01-21" ...
 $ value: num  761727 857102 749136 1237957 793982 ...
 - attr(*, "index_quo")= language ~index
  ..- attr(*, ".Environment")=<environment: R_GlobalEnv> 
 - attr(*, "index_time_zone")= chr "UTC"

我有一个名为store的数据框,并使用以下内容创建了滚动示例。

periods_train <- 4 * 50
periods_test  <- 1 * 50

rolling_origin_resamples <- rolling_origin(
  store,
  initial    = periods_train,
  assess     = periods_test,
  cumulative = FALSE
)

rolling_origin_resamples$splits

我创建了与sun_spots数据相同的函数。

predict_keras_lstm <- function(split, epochs = 300, ...) {

  lstm_prediction <- function(split, epochs, ...) {

    # 5.1.2 Data Setup
    df_trn <- training(split)
    df_tst <- testing(split)

    df <- bind_rows(
      df_trn %>% add_column(key = "training"),
      df_tst %>% add_column(key = "testing")
    ) %>% 
      as_tbl_time(index = index)

    # 5.1.3 Preprocessing
    rec_obj <- recipe(value ~ ., df) %>%
      step_sqrt(value) %>%
      step_center(value) %>%
      step_scale(value) %>%
      prep()

    df_processed_tbl <- bake(rec_obj, df)

    center_history <- rec_obj$steps[[2]]$means["value"]
    scale_history  <- rec_obj$steps[[3]]$sds["value"]

    # 5.1.4 LSTM Plan
    lag_setting  <- 120 # = nrow(df_tst)
    batch_size   <- 40
    train_length <- 440
    tsteps       <- 1
    epochs       <- epochs

    # 5.1.5 Train/Test Setup
    lag_train_tbl <- df_processed_tbl %>%
      mutate(value_lag = lag(value, n = lag_setting)) %>%
      filter(!is.na(value_lag)) %>%
      filter(key == "training") %>%
      tail(train_length)

    x_train_vec <- lag_train_tbl$value_lag
    x_train_arr <- array(data = x_train_vec, dim = c(length(x_train_vec), 1, 1))

    y_train_vec <- lag_train_tbl$value
    y_train_arr <- array(data = y_train_vec, dim = c(length(y_train_vec), 1))

    lag_test_tbl <- df_processed_tbl %>%
      mutate(
        value_lag = lag(value, n = lag_setting)
      ) %>%
      filter(!is.na(value_lag)) %>%
      filter(key == "testing")

    x_test_vec <- lag_test_tbl$value_lag
    x_test_arr <- array(data = x_test_vec, dim = c(length(x_test_vec), 1, 1))

    y_test_vec <- lag_test_tbl$value
    y_test_arr <- array(data = y_test_vec, dim = c(length(y_test_vec), 1))

    # 5.1.6 LSTM Model
    model <- keras_model_sequential()

    model %>%
      layer_lstm(units            = 50, 
                 input_shape      = c(tsteps, 1), 
                 batch_size       = batch_size,
                 return_sequences = TRUE, 
                 stateful         = TRUE) %>% 
      layer_lstm(units            = 50, 
                 return_sequences = FALSE, 
                 stateful         = TRUE) %>% 
      layer_dense(units = 1)

    model %>% 
      compile(loss = 'mae', optimizer = 'adam')

    # 5.1.7 Fitting LSTM
    for (i in 1:epochs) {
      model %>% fit(x          = x_train_arr, 
                    y          = y_train_arr, 
                    batch_size = batch_size,
                    epochs     = 1, 
                    verbose    = 1, 
                    shuffle    = FALSE)

      model %>% reset_states()
      cat("Epoch: ", i)

    }

    # 5.1.8 Predict and Return Tidy Data
    # Make Predictions
    pred_out <- model %>% 
      predict(x_test_arr, batch_size = batch_size) %>%
      .[,1] 

    # Retransform values
    pred_tbl <- tibble(
      index   = lag_test_tbl$index,
      value   = (pred_out * scale_history + center_history)^2
    ) 

    # Combine actual data with predictions
    tbl_1 <- df_trn %>%
      add_column(key = "actual")

    tbl_2 <- df_tst %>%
      add_column(key = "actual")

    tbl_3 <- pred_tbl %>%
      add_column(key = "predict")

    # Create time_bind_rows() to solve dplyr issue
    time_bind_rows <- function(data_1, data_2, index) {
      index_expr <- enquo(index)
      bind_rows(data_1, data_2) %>%
        as_tbl_time(index = !! index_expr)
    }

    ret <- list(tbl_1, tbl_2, tbl_3) %>%
      reduce(time_bind_rows, index = index) %>%
      arrange(key, index) %>%
      mutate(key = as_factor(key))

    return(ret)

  }

  safe_lstm <- possibly(lstm_prediction, otherwise = NA)

  safe_lstm(split, epochs, ...)

}

我运行以下命令,以运行模型和函数:

results <- store %>%
  mutate(predict = map(splits, predict_keras_lstm, epochs = 2))

results$predict

这一次我得到了NA值的列表:

[[1]]
[1] NA

[[2]]
[1] NA

[[3]]
[1] NA

我要去哪里错了?为什么我没有在这里得到值列表?

数据:

    store <- structure(list(index = structure(c(13878, 13885, 13892, 13899, 
13906, 13913, 13920, 13927, 13934, 13941, 13948, 13955, 13962, 
13969, 13976, 13983, 13990, 13997, 14004, 14011, 14018, 14025, 
14032, 14039, 14046, 14053, 14060, 14067, 14074, 14081, 14088, 
14095, 14102, 14109, 14116, 14123, 14130, 14137, 14144, 14151, 
14158, 14165, 14172, 14179, 14186, 14193, 14200, 14207, 14214, 
14221, 14228, 14235, 14242, 14249, 14256, 14263, 14270, 14277, 
14284, 14291, 14298, 14305, 14312, 14319, 14326, 14333, 14340, 
14347, 14354, 14361, 14368, 14375, 14382, 14389, 14396, 14403, 
14410, 14417, 14424, 14431, 14438, 14445, 14452, 14459, 14466, 
14473, 14480, 14487, 14494, 14501, 14508, 14515, 14522, 14529, 
14536, 14543, 14550, 14557, 14564, 14571, 14578, 14585, 14592, 
14599, 14606, 14613, 14620, 14627, 14634, 14641, 14648, 14655, 
14662, 14669, 14676, 14683, 14690, 14697, 14704, 14711, 14718, 
14725, 14732, 14739, 14746, 14753, 14760, 14767, 14774, 14781, 
14788, 14795, 14802, 14809, 14816, 14823, 14830, 14837, 14844, 
14851, 14858, 14865, 14872, 14879, 14886, 14893, 14900, 14907, 
14914, 14921, 14928, 14935, 14942, 14949, 14956, 14963, 14970, 
14977, 14984, 14991, 14998, 15005, 15012, 15019, 15026, 15033, 
15040, 15047, 15054, 15061, 15068, 15075, 15082, 15089, 15096, 
15103, 15110, 15117, 15124, 15131, 15138, 15145, 15152, 15159, 
15166, 15173, 15180, 15187, 15194, 15201, 15208, 15215, 15222, 
15229, 15236, 15243, 15250, 15257, 15264, 15271, 15278, 15285, 
15292, 15299, 15306, 15313, 15320, 15327, 15334, 15341, 15348, 
15355, 15362, 15369, 15376, 15383, 15390, 15397, 15404, 15411, 
15418, 15425, 15432, 15439, 15446, 15453, 15460, 15467, 15474, 
15481, 15488, 15495, 15502, 15509, 15516, 15523, 15530, 15537, 
15544, 15551, 15558, 15565, 15572, 15579, 15586, 15593, 15600, 
15607, 15614, 15621, 15628, 15635), class = "Date"), value = c(761726.58, 
857101.89, 749136.32, 1237956.68, 793981.61, 861052.71, 1740167.84, 
1348565.28, 1418102.37, 1244809.11, 2570026.85, 1072145.99, 953054.03, 
14215.44, 11587.59, 8896.44, 79055.33, 26668.41, 1991113.48, 
760008.1, 2366.41, 1960955.3, 2928948.74, 2215875.85, 2939086.3, 
3869296.31, 910097.65, 804338.73, 1648004.84, 1407837.26, 557153.11, 
1231785.66, 4430006.32, 1933735.74, 1733775.45, 1092611.43, 2586296.61, 
4215401.23, 989029.96, 1953652.01, 787519.23, 5492009.39, 1469597.12, 
1373534.49, 596375.34, 1467484.44, 2435976.86, 885934.08, 6523809.68, 
823400.97, 1939457.08, 464507.02, 1301133.33, 1374124.22, 1595500.29, 
2565051.31, 1845506.37, 3094490.26, 1326632.23, 767008.73, 697040.51, 
3522981.49, 1055205.33, 1512524.67, 1225637.5, 4461913.91, 807578.68, 
1025566.74, 1652269.52, 471748.58, 2501399.54, 2187112.61, 2460378.95, 
1640399.27, 2662477.2, 1077362.65, 2287778.59, 2247735.14, 1199470.58, 
1179229.13, 915205.03, 1864292.73, 2196493.17, 1219440.7, 576920.63, 
1651739.39, 3397835.24, 1224438.39, 4374050.83, 1815882.5, 2238561.63, 
4382539.55, 2026436.2, 10762505.13, 2202860.26, 980998.61, 1149598.09, 
1232106.7, 3592317.62, 867381.86, 4468397.64, 1145633.43, 1453154.82, 
1792573.53, 513029.49, 1274902.36, 4116335.16, 3435329.44, 1348027.01, 
2307152.14, 2281622.99, 1010530.08, 492632.7, 1522271.77, 522117.66, 
1087265.33, 4744783.09, 1875644.61, 1645967.28, 1160101.62, 1103553.74, 
668894.97, 532129.58, 5760909.29, 649484.14, 1355513.52, 1105582.38, 
2779436.47, 707437.29, 2814518.63, 3904727.33, 2007550.84, 592833.82, 
1106458.42, 2101013.07, 679443.13, 2342973.16, 2594914.41, 1313594.69, 
1816061.14, 813415.22, 1067061.86, 521107.66, 1244363.21, 977612.55, 
5067710.87, 3942903.86, 1267291.65, 634221.4, 2159533.7, 4415212.19, 
770794.16, 2812603.25, 1100106.06, 2583188.83, 950864.32, 922904.1, 
1431831.06, 2136347.26, 802885.62, 1867545.91, 2418341.5, 1337377.52, 
3989038.18, 4326916.99, 1628586.37, 2870183.88, 904918.85, 2459186.34, 
1283687.25, 1427404.27, 4836615.46, 1420714.78, 2433924.29, 714438.18, 
3343883.07, 4621820.27, 1935603.62, 767619.85, 4978707.68, 774006.62, 
2015113.66, 1679598.18, 1774966.46, 1128457.62, 1290245.53, 1660377.04, 
1003629.44, 2168572.82, 5083999.79, 2525852.71, 1679668.93, 932990.97, 
1419901.32, 2771279.76, 3428132.64, 1708623.96, 1549779.39, 982796.05, 
1012496.65, 5088335.32, 966540.48, 7963320.18, 1949377.92, 5210109.02, 
1082791.1, 2809864.15, 1589905.02, 1069575.06, 660136.82, 1811517.77, 
959474.99, 2956794.7, 1105908.93, 2333185.07, 3775967.1, 1008845.83, 
2792402.78, 3160232.32, 2125294, 2791000.82, 1805276.91, 5645546.83, 
1528778.23, 3165021.79, 2708298.01, 810602.46, 830353.84, 1647064.41, 
2904710.2, 946931.59, 2157189.04, 536283.04, 786015.66, 2136827.03, 
1700772.9, 3204220.16, 1339197.02, 1082632.61, 1098236.22, 1822219.24, 
3638890.87, 1945421.11, 2103100.44, 926220.3, 1714574.31, 1125085.31, 
835445.36, 6245495.97, 1687818.07, 2224868.84, 1078471.57)), class = c("tbl_time", 
"tbl_df", "tbl", "data.frame"), row.names = c(NA, -252L), index_quo = ~index, index_time_zone = "UTC")

0 个答案:

没有答案