Keras for R: How to debug a custom loss function (ValueError: Dimensions must be equal, but are ...)?

Posted: 2020-06-05 13:51:22

Tags: python r tensorflow keras

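I am trying to implement a custom loss function (essentially an alternative to the cross-entropy loss) for a real-valued outcome and a predicted vector of bin probabilities. The inputs of my_loss are therefore y_true, a single real value; y_pred, a vector of probabilities for the length(bin_edges) - 1 bins (31 in the example below); and bin_edges, the vector of bin edges, of length 32.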
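Note that during fit() Keras calls the loss on a whole mini-batch rather than on a single observation, so y_true and y_pred carry an extra leading batch dimension. The base R sketch below uses plain matrices as stand-ins for the tensors (the batch size of 4 and the helper softmax_rows are illustrative assumptions, not part of the original code) to show the shapes my_loss would receive with the model defined below. In TensorFlow 2.x a 1-D target vector such as y_train is typically expanded to shape (batch_size, 1) before it reaches the loss.

```r
# Base R stand-ins for the tensors Keras passes to the loss during fit().
# Assumptions: a mini-batch of 4 observations and the 32 bin edges used below.
bin_edges <- seq(-15, 16, 1)
n_bins <- length(bin_edges) - 1   # 31 bins
batch_size <- 4

# Row-wise softmax, mimicking layer_dense(..., activation = 'softmax')
softmax_rows <- function(logits) {
  e <- exp(logits - apply(logits, 1, max))  # subtract row maxima for stability
  e / rowSums(e)
}

set.seed(1)
logits <- matrix(rnorm(batch_size * n_bins), nrow = batch_size)
y_pred <- softmax_rows(logits)
y_true <- matrix(rnorm(batch_size), ncol = 1)  # one real target per row

dim(y_true)      # 4 1
dim(y_pred)      # 4 31
rowSums(y_pred)  # 1 1 1 1
```

To check this on the real tensors, one option is to temporarily add print(k_int_shape(y_true)) and print(k_int_shape(y_pred)) at the top of my_loss. Because the loss is traced into a graph, these print once during tracing and should show a leading (unknown) batch dimension on both.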

Here is (hopefully) reproducible example code with simulated data:

rm(list=ls())

library(keras)

n <- 5000
x <- matrix(rnorm(5*n, 0, 1), nrow = n)
beta1 <- matrix(rnorm(5*n, 0, 1), nrow = n)

y <- apply(x * beta1, 1, sum) 

bin_edges <- seq(-15, 16, 1)
y[y <= -15] <- -15
y[y > 16] <- 16

ind_train <- 1:(n-1000)
ind_test <- (max(ind_train)+1):n

x_train <- x[ind_train,]
x_test <- x[ind_test,]

y_train <- y[ind_train]
y_test <- y[ind_test]

my_loss <- function(y_true, y_pred, bin_edges){

  nb <- length(bin_edges)
  left_k <- k_constant(bin_edges[1:(nb-1)])
  right_k <- k_constant(bin_edges[2:nb])

  z_k <- k_cast_to_floatx(k_all(list(k_less_equal(left_k, y_true), k_less_equal(y_true, right_k)), axis = 1)) * y_true + 
    k_cast_to_floatx(k_all(list(k_less_equal(right_k, y_true), k_greater(y_true, left_k)), axis = 1)) * right_k + 
    k_cast_to_floatx(k_all(list(k_less_equal(y_true, left_k), k_less(y_true, right_k)), axis = 1)) * left_k

  min_left_k <- k_constant(bin_edges[1])
  max_right_k <- k_max(right_k)

  z_transf_k <- z_k +
    k_cast_to_floatx(k_less(y_true, min_left_k)) * (y_true * k_ones_like(left_k) - z_k) * k_concatenate(list(k_ones(c(1)),k_zeros(c(length(left_k) - 1)))) +
    k_cast_to_floatx(k_greater(y_true, max_right_k)) * (y_true * k_ones_like(left_k) - z_k) * k_concatenate(list(k_zeros(c(length(left_k) - 1)), k_ones(c(1))))

  L_k <- k_gather(k_cumsum(k_concatenate(list(k_zeros_like(y_true), y_pred))), 1:(length(left_k)))
  U_k <- k_ones(c(length(left_k)) )- k_cumsum(y_pred)

  w_k <- (z_transf_k - left_k) / (right_k - left_k)

  punif_w_k <- k_clip(w_k, 0, 1)
  k_sum((right_k - left_k) * (k_abs(w_k - punif_w_k) + 
                                (k_ones(c(length(left_k))) - L_k - U_k) * k_square(punif_w_k) - 
                                punif_w_k * (k_ones(c(length(left_k))) - k_constant(2)*L_k) + 
                                k_square((k_ones(c(length(left_k))) - L_k - U_k)) / k_constant(3) + 
                                (k_ones(c(length(left_k))) - L_k) * U_k))
}

model <- keras_model_sequential() 
model %>% 
  layer_dense(units = 100, activation = 'relu', input_shape = c(5)) %>% 
  layer_dense(units = length(bin_edges) - 1, activation = 'softmax')

model %>% compile(
  optimizer = optimizer_adam(),
  metrics = c('accuracy'),
  loss = function(y_true, y_pred)
    my_loss(y_true, y_pred, bin_edges)
)

summary(model)

history <- model %>% fit(
  x_train, y_train,
  epochs = 10, batch_size = 1024,
  validation_split = 0.2
)
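The three indicator terms that build z_k appear intended to clamp y_true into each bin, i.e. z_k = min(max(y, left_k), right_k). The base R sketch below compares a literal transcription of those three terms with pmin(pmax(...)) for a single observation; the helper names z_indicator and z_clamp are hypothetical. Away from the bin edges the two agree, but when y lies exactly on an edge two of the indicators are true at once and the value is counted twice. Since the code above clips y to -15 and 16, any clipped observation lands exactly on the outermost edges.

```r
# Literal transcription of the three indicator terms in my_loss (one observation).
z_indicator <- function(y, left, right) {
  (left <= y & y <= right) * y +
    (right <= y & y > left) * right +
    (y <= left & y < right) * left
}

# Presumably intended result: y clamped into [left, right] for every bin.
z_clamp <- function(y, left, right) pmin(pmax(y, left), right)

bin_edges <- c(0, 1, 2)
left <- bin_edges[-length(bin_edges)]
right <- bin_edges[-1]

z_indicator(0.5, left, right)  # 0.5 1.0
z_clamp(0.5, left, right)      # 0.5 1.0
z_indicator(1, left, right)    # 2 2  (y = 1 is an edge, counted twice)
z_clamp(1, left, right)        # 1 1
```

In the Keras version, the same clamp could be written with the element-wise backend functions, for example k_minimum(k_maximum(y_true, left_k), right_k), which also broadcasts a (batch_size, 1) y_true against the length-31 edge vectors.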

Running this produces the following error message:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  RuntimeError: Evaluation error: ValueError: Dimensions must be equal, but are 31 and 32 for 'loss/dense_115_loss/Sub_7' (op: 'Sub') with input shapes: [31], [31,32]..

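Is there any way to find out what the problem is? In particular, I don't know where the shape [31,32] comes from. Essentially, y_pred should be a probability vector that sums to 1, with length equal to length(bin_edges) - 1, i.e. 31.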
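One plausible origin of the [31,32] shape, assuming y_true arrives as (batch_size, 1) and y_pred as (batch_size, 31): k_concatenate() joins along the last axis by default, giving (batch_size, 32), while k_gather() selects along the first axis, so k_gather(..., 1:31) picks 31 rows (observations) of that matrix instead of 31 bins, giving [31, 32]. The base R sketch below mimics those two steps with cbind() and row indexing; the batch size of 64 is an arbitrary assumption that only needs to be at least 31 so that the rows exist. For a single observation the same expression yields a vector of length 31, which would explain why a one-observation comparison looked correct.

```r
# Stand-ins for one mini-batch (assumption: 64 observations, 31 bins).
batch_size <- 64
n_bins <- 31
y_true <- matrix(0, nrow = batch_size, ncol = 1)
y_pred <- matrix(1 / n_bins, nrow = batch_size, ncol = n_bins)

# k_concatenate(list(k_zeros_like(y_true), y_pred)) joins along the last axis.
# (k_cumsum() does not change the shape, so it is omitted here.)
padded <- cbind(y_true * 0, y_pred)
dim(padded)        # 64 32

# k_gather(x, 1:31) selects along the FIRST axis, i.e. 31 rows (observations):
L_k <- padded[1:n_bins, , drop = FALSE]
dim(L_k)           # 31 32, the shape in the error message

# Presumably intended: exclusive cumulative sum over the bin axis, per row.
L_k_bins <- t(apply(y_pred, 1, function(p) c(0, cumsum(p))[1:n_bins]))
dim(L_k_bins)      # 64 31
```

If this is the cause, the bin-wise quantities would need to be computed along the last axis, for example k_cumsum(y_pred, axis = -1) - y_pred for the mass to the left of each bin. Note also that k_cumsum() defaults to axis = 1, which under the R interface's 1-based axis convention is the batch axis; this is worth confirming with ?k_cumsum.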

Also, what does (op: 'Sub') in the error message stand for, and what does it mean?
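Sub is TensorFlow's element-wise subtraction operation, which the - operator creates between tensors. loss/dense_115_loss/Sub_7 is an auto-generated node name inside the loss scope (TensorFlow appends _1, _2, ... to repeated op names), so it points at one particular subtraction in my_loss. Like other element-wise ops, Sub broadcasts its operands: shapes are aligned from the right, and each pair of dimensions must be equal or 1. The hypothetical base R helper broadcast_shape() below implements that rule to show why [31] and [31,32] are rejected while [31] and [64,31] are accepted.

```r
# NumPy/TensorFlow-style broadcasting rule for element-wise ops such as Sub.
# Returns the broadcast shape, or stops with a message mirroring TensorFlow's.
broadcast_shape <- function(a, b) {
  n <- max(length(a), length(b))
  # Left-pad the shorter shape with 1s so both are aligned from the right
  a <- c(rep(1, n - length(a)), a)
  b <- c(rep(1, n - length(b)), b)
  out <- numeric(n)
  for (i in seq_len(n)) {
    if (a[i] == b[i] || a[i] == 1 || b[i] == 1) {
      out[i] <- max(a[i], b[i])
    } else {
      stop(sprintf("Dimensions must be equal, but are %d and %d", a[i], b[i]))
    }
  }
  out
}

broadcast_shape(31, c(64, 31))   # 64 31
tryCatch(broadcast_shape(31, c(31, 32)),
         error = function(e) conditionMessage(e))
# "Dimensions must be equal, but are 31 and 32"
```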

I compared the output of the loss function with an R implementation, and the values are correct.
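A comparison against plain R only exercises the batch dimension if it is run on several observations at once. The sketch below is a hypothetical base R transcription of my_loss for a single observation (my_loss_r and my_loss_batch are illustrative names; the asker's own R implementation is not shown), plus a batched wrapper returning one value per row, which is the shape a Keras loss is generally expected to return before Keras averages over the batch. As a sanity check, with a single bin [0, 1], all probability on that bin and y = 0.5, the formula evaluates to 1/12.

```r
# Hypothetical base R transcription of my_loss for ONE observation:
# y is a single real value, p a probability vector with one entry per bin.
my_loss_r <- function(y, p, bin_edges) {
  nb <- length(bin_edges)
  left <- bin_edges[1:(nb - 1)]
  right <- bin_edges[2:nb]
  K <- length(left)

  # Same three indicator terms as in the Keras code
  z <- (left <= y & y <= right) * y +
    (right <= y & y > left) * right +
    (y <= left & y < right) * left

  # Replace the first/last bin value by y if y lies outside all bins
  z_transf <- z +
    (y < bin_edges[1]) * (y - z) * c(1, rep(0, K - 1)) +
    (y > max(right)) * (y - z) * c(rep(0, K - 1), 1)

  L <- c(0, cumsum(p))[1:K]   # probability mass in bins before bin k
  U <- 1 - cumsum(p)          # probability mass in bins after bin k

  w <- (z_transf - left) / (right - left)
  pw <- pmin(pmax(w, 0), 1)   # counterpart of k_clip(w_k, 0, 1)
  sum((right - left) * (abs(w - pw) +
                          (1 - L - U) * pw^2 -
                          pw * (1 - 2 * L) +
                          (1 - L - U)^2 / 3 +
                          (1 - L) * U))
}

# Batched version: one loss value per row of y_pred.
my_loss_batch <- function(y_true, y_pred, bin_edges) {
  vapply(seq_len(nrow(y_pred)),
         function(i) my_loss_r(y_true[i], y_pred[i, ], bin_edges),
         numeric(1))
}

my_loss_r(0.5, 1, c(0, 1))   # 0.08333333 (= 1/12)
```

For example, my_loss_batch(c(0.3, -2.7), rbind(rep(1/31, 31), rep(1/31, 31)), seq(-15, 16, 1)) returns two values that can be compared element by element with the per-observation output of the Keras loss.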

The output of devtools::session_info() is:

─ Session info ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 setting  value                       
 version  R version 3.6.3 (2020-02-29)
 os       Ubuntu 20.04 LTS            
 system   x86_64, linux-gnu           
 ui       RStudio                     
 language (EN)                        
 collate  en_US.UTF-8                 
 ctype    en_US.UTF-8                 
 tz       Europe/Berlin               
 date     2020-06-01                  

─ Packages ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 package     * version      date       lib source                        
 assertthat    0.2.1        2019-03-21 [1] CRAN (R 3.6.3)                
 backports     1.1.7        2020-05-13 [1] CRAN (R 3.6.3)                
 base64enc     0.1-3        2015-07-28 [1] CRAN (R 3.6.3)                
 callr         3.4.3        2020-03-28 [1] CRAN (R 3.6.3)                
 cli           2.0.2        2020-02-28 [1] CRAN (R 3.6.3)                
 crayon        1.3.4        2017-09-16 [1] CRAN (R 3.6.3)                
 desc          1.2.0        2018-05-01 [1] CRAN (R 3.6.3)                
 devtools      2.3.0        2020-04-10 [1] CRAN (R 3.6.3)                
 digest        0.6.25       2020-02-23 [1] CRAN (R 3.6.3)                
 ellipsis      0.3.1        2020-05-15 [1] CRAN (R 3.6.3)                
 fansi         0.4.1        2020-01-08 [1] CRAN (R 3.6.3)                
 fs            1.4.1        2020-04-04 [1] CRAN (R 3.6.3)                
 generics      0.0.2        2018-11-29 [1] CRAN (R 3.6.2)                
 glue          1.4.1        2020-05-13 [1] CRAN (R 3.6.3)                
 jsonlite      1.6.1        2020-02-02 [1] CRAN (R 3.6.3)                
 keras       * 2.3.0.0.9000 2020-05-22 [1] Github (rstudio/keras@561e748)
 lattice       0.20-41      2020-04-02 [4] CRAN (R 3.6.3)                
 magrittr      1.5          2014-11-22 [1] CRAN (R 3.6.1)                
 Matrix        1.2-18       2019-11-27 [4] CRAN (R 3.6.1)                
 memoise       1.1.0        2017-04-21 [1] CRAN (R 3.6.3)                
 pkgbuild      1.0.8        2020-05-07 [1] CRAN (R 3.6.3)                
 pkgload       1.0.2        2018-10-29 [1] CRAN (R 3.6.3)                
 prettyunits   1.1.1        2020-01-24 [1] CRAN (R 3.6.3)                
 processx      3.4.2        2020-02-09 [1] CRAN (R 3.6.3)                
 ps            1.3.3        2020-05-08 [1] CRAN (R 3.6.3)                
 R6            2.4.1        2019-11-12 [1] CRAN (R 3.6.3)                
 rappdirs      0.3.1        2016-03-28 [1] CRAN (R 3.6.3)                
 Rcpp          1.0.4.6      2020-04-09 [1] CRAN (R 3.6.3)                
 remotes       2.1.1        2020-02-15 [1] CRAN (R 3.6.3)                
 reticulate    1.15         2020-04-02 [1] CRAN (R 3.6.3)                
 rlang         0.4.6        2020-05-02 [1] CRAN (R 3.6.3)                
 rprojroot     1.3-2        2018-01-03 [1] CRAN (R 3.6.3)                
 rstudioapi    0.11         2020-02-07 [1] CRAN (R 3.6.3)                
 sessioninfo   1.1.1        2018-11-05 [1] CRAN (R 3.6.3)                
 tensorflow    2.2.0        2020-05-11 [1] CRAN (R 3.6.3)                
 testthat      2.3.2        2020-03-02 [1] CRAN (R 3.6.3)                
 tfruns        1.4          2018-08-25 [1] CRAN (R 3.6.3)                
 usethis       1.6.1        2020-04-29 [1] CRAN (R 3.6.3)                
 whisker       0.4          2019-08-28 [1] CRAN (R 3.6.3)                
 withr         2.2.0        2020-04-20 [1] CRAN (R 3.6.3)                
 zeallot       0.1.0        2018-01-28 [1] CRAN (R 3.6.3) 

0 answers
