Keras自定义图层中的权重不可训练

时间:2020-09-24 21:06:05

标签: r keras

我尝试创建一个rStudio / Keras定制层,包括其他keras层。我想像在GitHub https://github.com/rstudio/keras/issues/926上的本期中那样仅使用一个密集层来尝试它,但是除了自定义层之外,模型中没有其他层。我在非常简单的虚拟数据上进行了训练,但是在我的示例中,权重不仅没有出现在模型摘要中,而且也没有得到训练。

library(keras)
CustomDense <- R6::R6Class(
  "CustomDense",
  inherit = KerasLayer,
  lock_objects = FALSE,
  
   public = list(
    initialize = function(units) {
      self$units <- units
    },
    
    build = function(input_shape) {
      self$dense <- keras::layer_dense(units = self$units)
      super$build(input_shape)
    },
    
    call = function(input, mask = NULL) {
      output <- self$dense(input) 
    },
    
    compute_output_shape = function(input_shape) {
      input_shape[[2]] <- self$units 
      input_shape
    }
  )
)

custom_layer_dense <- function(object, 
                               units,
                               name = NULL, 
                               trainable = TRUE) {
  create_layer(CustomDense, 
               object, 
               args = list(units = units,
                           name = name,
                           trainable = trainable))
}

library(keras)
in_layer <- layer_input(shape=c(10L))
mid_layer <- custom_layer_dense(units=1L)
out_layer <- in_layer %>%
  mid_layer

model_custom <- keras_model(inputs=in_layer,outputs=out_layer)

model_custom %>% compile(
  optimizer = 'adam',
  loss = 'mae'
)

x_train <- array(rep(1,100),dim=c(10,10))
y_train <- array(rep(2,10),dim=c(10,1))

model_custom %>% fit(x_train,
                     y_train,
                     epochs=20)

在训练期间,损耗值保持完全相同,并且layer_dense的权重不变。但是您可以调用model_custom并收到输出。

>predict_on_batch(model_custom,x_train)
[,1]  [1,] -0.2629238  [2,] -0.2629238  [3,] -0.2629238  [4,] -0.2629238  [5,] -0.2629238  [6,] -0.2629238  [7,] -0.2629238  [8,] -0.2629238  [9,] -0.2629238 [10,] -0.2629238

我是否缺少自定义图层的重要内容?还是这个问题可以解决?

谢谢!

注意:尽管我使用R,但是我可以或多或少地将Python转换为R。因此,如果这是一个编码问题,也欢迎使用Python答案;)

注意:如果此代码适用于其他人,这也将很有趣,因此我知道我的版本存在问题。

会议信息: R版本3.6.1 keras_2.3.0.0 tensorflow_2.2.0 reticulate_1.15

0 个答案:

没有答案