使用Tensorflow API的批量标准化时,保存和加载权重的正确方法是什么?

时间:2019-02-05 16:34:27

标签: python tensorflow deep-learning

我一直在努力使用我训练和保存的批范数来加载Tensorflow CNN模型的权重。

我尝试在模型定义中同时使用tf.layers.batch_normalization和tf.contrib.layers.batch_norm。

当前的方法似乎可行,但杂乱无章,它是执行以下操作:

在定义批处理规范层时(使用contrib API),请如下添加变量集合:

with tf.variable_scope('g_weights', reuse=tf.AUTO_REUSE):
    #model definition...
    conv1_norm = tf.contrib.layers.batch_norm(conv1, is_training=training, \
             variables_collections=["g_batch_norm_non_trainable"])

然后将所述变量集合包括在模型的参数集合中:

t_vars = tf.trainable_variables()
g_vars = list(set([var for var in t_vars if 'g_' in var.name] + tf.get_collection("g_batch_norm_non_trainable")))
...
g_saver = tf.train.Saver(g_vars)
... train model...
g_saver.save(sess, "weights/generator/gen.ckpt")

然后加载,使用相同的方法收集g_vars,然后

g_saver.restore(sess, "./weights/generator/gen.ckpt")

我在这里问了一个先前的问题,一个可能的解决方案是

g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='g_weights')

但是,当重新加载砝码时,我得到了错误

key error g_weights/BatchNorm/beta/Adam not in checkpoint

有没有更简单的方法来保存使用批处理规范的模型?在PyTorch和Keras中,这非常简单,但是在Tensorflow中似乎没有很好的解决方案。

0 个答案:

没有答案