R Shiny无法识别reactVal的突变

时间:2019-06-10 11:20:30

标签: r keras shiny

将Keras模型声明为reactiveVal会导致Shiny在调用fit()时发生突变时无法识别。这是什么原因,最好的解决方法是什么?是否可以手动告诉Shiny某些反应变量已被突变?

Keras模型通常是就地突变的,我怀疑在这种情况下,只要在训练过程中更新模型的参数,就会导致Shiny丢失。还是可能是Keras是TensorFlow的接口(即代码的某些部分在R和Shiny的范围之外执行)?

在此独立的MWE中,显示任意权重。尽管模型是经过训练的,但显示的值永远不会更新(重量的确会改变,并且在每次训练之后都会打印到壳体上。)

library(keras)
library(shiny)

# Load data, reshape and normalize inputs and one-hot-encode labels
mnist <- dataset_mnist()
x.train <- array_reshape(mnist$train$x, c(nrow(mnist$train$x), 784)) / 255
y.train <- to_categorical(mnist$train$y, 10)

# Returns untrained Keras model. Don't worry about the details.
initialize_model <- function() {
  keras_model_sequential()                                                    %>%
  layer_dense(units = 5, activation = 'relu', input_shape = c(784))           %>%
  layer_dense(units = 10, activation = 'softmax')                             %>%
  compile(loss = 'categorical_crossentropy', optimizer = optimizer_rmsprop())
}

ui <- fluidPage(
  fluidRow(
    column(width = 12, align = "center",
      actionButton('train.model', label = 'Train Model'),
      br(), br(),
      code(textOutput('random.weight'))
    )
  )
)

server <- function(input, output) {

  MODEL         <- reactiveVal(initialize_model())
  RANDOM.WEIGHT <- reactive({ return(MODEL()$get_weights()[[3]][1,2]) })

  observeEvent(input$train.model, {
    # this causes the model object to be mutated in-place but RANDOM.WEIGHT
    # is never updated
    MODEL() %>% fit(x.train, y.train, epochs = 1, batch_size = 60000)
    cat("The weight's new value is = ", MODEL()$get_weights()[[3]][1,2], "\n")
  })
  output$random.weight <- renderText({ 
    return(paste0("RANDOM.WEIGHT() = ", RANDOM.WEIGHT()))
  })
}

app <- shinyApp(ui = ui, server = server)
runApp(app, port = 1337)

有什么想法吗? 谢谢


回复评论:

dput(initialize_model())的输出是什么?

structure(function (object)
{
    compose_layer(object, x)
}, class = c("keras.engine.sequential.Sequential",        
             "keras.engine.training.Model",
             "keras.engine.network.Network",
             "keras.engine.base_layer.Layer",              
             "tensorflow.python.training.checkpointable.base.CheckpointableBase",
             "python.builtin.object"), py_object = <environment>)

不幸的是,这并不能告诉我很多。

0 个答案:

没有答案