我一直在努力使用我训练和保存的批范数来加载Tensorflow CNN模型的权重。
我尝试在模型定义中同时使用tf.layers.batch_normalization和tf.contrib.layers.batch_norm。
当前的方法似乎可行,但杂乱无章,它是执行以下操作:
在定义批处理规范层时(使用contrib API),请如下添加变量集合:
with tf.variable_scope('g_weights', reuse=tf.AUTO_REUSE):
#model definition...
conv1_norm = tf.contrib.layers.batch_norm(conv1, is_training=training, \
variables_collections=["g_batch_norm_non_trainable"])
然后将所述变量集合包括在模型的参数集合中:
t_vars = tf.trainable_variables()
g_vars = list(set([var for var in t_vars if 'g_' in var.name] + tf.get_collection("g_batch_norm_non_trainable")))
...
g_saver = tf.train.Saver(g_vars)
... train model...
g_saver.save(sess, "weights/generator/gen.ckpt")
然后加载,使用相同的方法收集g_vars
,然后
g_saver.restore(sess, "./weights/generator/gen.ckpt")
我在这里问了一个先前的问题,一个可能的解决方案是
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='g_weights')
但是,当重新加载砝码时,我得到了错误
key error g_weights/BatchNorm/beta/Adam not in checkpoint
有没有更简单的方法来保存使用批处理规范的模型?在PyTorch和Keras中,这非常简单,但是在Tensorflow中似乎没有很好的解决方案。