保存/恢复批次规范层(TensorFlow)的问题

时间:2018-12-27 06:44:58

标签: tensorflow

我正在尝试一个其中具有多个批处理规范层的模型。问题是,当我还原模型时,似乎正在为批范数层选择随机权重(而不是出于训练的总体权重),并产生不良的测试输出。

我查看了在线发布的多种解决方案,但还没有走运。请看下面的代码(包含我尝试过的解决方案)

我环顾了stackoverflow,并尝试了各种解决方案,包括添加批处理规范依赖项并确保指定了保护程序var_list。似乎没有任何作用。

我正在使用Tensorflow版本1.10,并且还尝试使用save_npz和load_and_assign_npz(在那里存在相同问题)。

这是我模型的一部分看起来像一个想法-

1.model.py

def _batch_normalization(input_tensor, is_train, gamma_init, name):
    return tf.layers.batch_normalization(input_tensor, training=is_train, name=name)

gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("u_net_bn", reuse=reuse):
    tl.layers.set_name_reuse(reuse)
    inputs = InputLayer(x, name='inputs')

    conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, name='conv1_1')

    conv1 = _batch_normalization(conv1.outputs, is_train=is_train, gamma_init=gamma_init, name='bn1')

    conv1 = InputLayer(conv1, name='bn1_fix')

    conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, name='conv1_2')

    conv1 = _batch_normalization(conv1.outputs, is_train=is_train, gamma_init=gamma_init, name='bn2')

    conv1 = InputLayer(conv1, name='bn2_fix')

    pool1 = MaxPool2d(conv1, (2, 2), name='pool1')

    conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, name='conv2_1')

2.main.py

 with tf.control_dependencies(extra_update_ops):

     ## Pretrain

     g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)

     ## SRGAN

     g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)

     d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

保存-

saver = tf.train.Saver(var_list=tf.global_variables())
savepath = saver.save(sess, checkpoint_dir + '/g_{}_{}_init.npz'.format(tl.global_flag['mode'], epoch))

Restore --

sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))

saver = tf.train.Saver()

saver.restore(sess,tf.train.latest_checkpoint('./checkpoint'))

1 个答案:

答案 0 :(得分:0)

在不查看完整源代码的情况下,很难说出您的特定情况出了什么问题。保存/还原BN层似乎正常。

>>> import tensorflow as tf
>>> tf.__version__
'1.10.0'
>>> bn = tf.layers.batch_normalization(tf.placeholder(tf.float32, shape=(1, 1)))
>>> init = tf.global_variables_initializer()
>>> sess = tf.Session()
>>> sess.run(init)
>>> sess.run(tf.global_variables())
[array([1.], dtype=float32), array([0.], dtype=float32), array([0.], dtype=float32), array([1.], dtype=float32)]

请注意,所有全局变量均已由BN层创建。训练循环将进行增量更新,可以通过以下方式进行模拟。

>>> for v in tf.global_variables():
...     sess.run(v.assign([42]))
... 
array([42.], dtype=float32)
array([42.], dtype=float32)
array([42.], dtype=float32)
array([42.], dtype=float32)

所有变量现在等于[42]。是时候保存/恢复了。

>>> saver = tf.train.Saver(tf.global_variables())
>>> saver.save(sess, "/tmp/chkpt")
'/tmp/chkpt'
>>> sess.run(init)  # Reinitialize all variables.
>>> sess.run(tf.global_variables())
[array([1.], dtype=float32), array([0.], dtype=float32), array([0.], dtype=float32), array([1.], dtype=float32)]

最后两行确保将变量重置为其初始值。最后

>>> saver.restore(sess, "/tmp/chkpt")
INFO:tensorflow:Restoring parameters from /tmp/chkpt
>>> sess.run(tf.global_variables())
[array([42.], dtype=float32), array([42.], dtype=float32), array([42.], dtype=float32), array([42.], dtype=float32)]