如何在tf.data.Dataset.map中使用预训练的keras模型进行推理?

时间:2019-03-18 18:18:48

标签: python r tensorflow keras tensorflow-datasets

我有一个预先训练的模型,我正在尝试构建另一个模型,该模型将先前模型的输出作为输入。我不想端到端地训练模型,并且只想将第一个模型用于推理。第一个模型是使用tf.data.Dataset管道进行训练的,而我的第一个倾向是将该模型作为另一个dataset.map()操作集成到管道的尾部,但是我对此有疑问。在此过程中,我遇到了20个不同的错误,每个错误都与上一个错误无关。批处理规范化层似乎尤其是一个痛点。

下面是一个最小的入门示例,它说明了该问题。它用R编写,但也欢迎使用python回答。

我正在使用来自tf.keras的tensorflow-gpu版本1.13.1和keras

library(reticulate)
library(tensorflow)
library(keras)
library(tfdatasets)
use_implementation("tensorflow")

model_weights_path <- 'model-weights.h5'

arr <- function(...) 
  np_array(array(seq_len(prod(unlist(c(...)))), unlist(c(...))), dtype = 'float32')

new_model <- function(load_weights = TRUE) {
  model <- keras_model_sequential() %>% 
    layer_conv_1d(5, 5, activation = 'relu', input_shape = shape(150, 10)) %>%
    layer_batch_normalization() %>%
    layer_flatten() %>%
    layer_dense(10, activation = 'softmax')
  if (load_weights)
    load_model_weights_hdf5(model, model_weights_path)
  freeze_weights(model)
  model
}

if(!file.exists(model_weights_path)) {
  model <- new_model(FALSE) 
  save_model_weights_hdf5(model, model_weights_path)
}

model <- new_model()

data <- arr(20, 150, 10)
ds <- tfdatasets::tensors_dataset(data) %>% 
  dataset_repeat()

ds2 <- ds %>% 
  dataset_map(function(x) {
    model(x)
  })

try(nb <- next_batch(ds2))

sess <- k_get_session()
it <- make_iterator_initializable(ds2)
sess$run(iterator_initializer(it))
nb <- it$get_next()

try(sess$run(nb))

sess$run(tf$initialize_all_variables())

try(sess$run(nb))

1 个答案:

答案 0 :(得分:1)

也许这不会直接回答您的问题,因为我不熟悉 R。但我最近使用 tf.data 构建了一个输入管道。

generate_images 函数使用 .map 进行映射,并使用经过训练的生成器模型生成新图像。

gen_model = tf.keras.models.load_model(artifact_dir+'/'+generators[-1], compile=False)

NOISE_DIM = 100

def generate_images(l):
    # generate images using the trained generator
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    images = gen_model(noise)

    # prepare the images for resize_and_preprocess function
    images = tf.squeeze(images, axis=-1)
    images = images*0.5+0.5
    images = tf.image.convert_image_dtype(images, dtype=tf.uint8)

    return images

genloader = tf.data.Dataset.from_tensors([1])

genloader = (
    genloader
    .map(generate_images, num_parallel_calls=AUTO)
    .map(resize_and_preprocess, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

关于批量归一化,它在训练和推理阶段表现不同。在基于 Python 的 TensorFlow 中,当使用具有批量归一化层的预训练模型时,需要传递 training=False