Unable to restore a TensorFlow network from a checkpoint

Time: 2016-08-19 09:01:49

Tags: tensorflow

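I have built an autoencoder that "converts" activations from VGG19.relu4_1 back into pixels. I use the new convenience functions in tensorflow.contrib.layers (as of TF 0.10rc0). The code has a layout similar to TensorFlow's CIFAR10 tutorial: train.py does the training and writes checkpoints to disk, and eval.py polls for new checkpoint files and runs inference on them.

My problem is that evaluation is never as good as training, neither in terms of the value of the loss function nor when I look at the output images (even when it runs on the same images as training). This makes me think it has something to do with the restore process.

When I look at the training output in TensorBoard, it looks fine (eventually), so I don't think there is anything wrong with my net itself.

My net looks like this: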

import tensorflow.contrib.layers as contrib

bn_params = {
    "is_training": is_training,
    "center": True,
    "scale": True
}

tensor = contrib.convolution2d_transpose(vgg_output, 64*4, 4,
    stride=2,
    normalizer_fn=contrib.batch_norm,
    normalizer_params=bn_params,
    scope="deconv1")
tensor = contrib.convolution2d_transpose(tensor, 64*2, 4,
    stride=2,
    normalizer_fn=contrib.batch_norm,
    normalizer_params=bn_params,
    scope="deconv2")
.
.
.

In train.py I do this to save checkpoints:

variable_averages = tf.train.ExponentialMovingAverage(mynet.MOVING_AVERAGE_DECAY)
variables_averages_op = variable_averages.apply(tf.trainable_variables())

with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

while training:
    # train (with batch normalization's is_training = True)
    if time_to_checkpoint:
        saver.save(sess, checkpoint_path, global_step=step)

In eval.py I do this:

# run code that creates the net

variable_averages = tf.train.ExponentialMovingAverage(
                  mynet.MOVING_AVERAGE_DECAY)
saver = tf.train.Saver(variable_averages.variables_to_restore())

while polling:
    # sleep and check for new checkpoint files
    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        init_local = tf.initialize_local_variables()
        sess.run([init, init_local])
        saver.restore(sess, checkpoint_path)

        # run inference (with batch normalization's is_training = False)

The loss function

Blue is the training loss, orange is the evaluation loss.

1 answer:

Answer 0: (score: 1)

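The problem was that I used tf.train.AdamOptimizer() directly. During optimization it never ran the ops defined in contrib.batch_norm that compute the running mean/variance of the input, so the stored mean/variance always stayed at 0.0 / 1.0.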
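As a rough illustration of the failure mode described above, here is a toy pure-Python sketch. It is not TensorFlow code and not the asker's network: `ToyBatchNorm`, `train` and `eval_output_mean` are hypothetical names invented for this example, and the decay, epsilon and input distribution (mean 5, standard deviation 2) are arbitrary assumptions. It only shows that when the moving-statistics update never runs during training, inference normalizes with the initial mean 0.0 and variance 1.0, so its output differs from what training saw.

```python
import random


class ToyBatchNorm:
    """Toy 1-D batch norm without learned scale/offset (hypothetical, not a TF API)."""

    def __init__(self, decay=0.9, eps=1e-3):
        self.decay = decay
        self.eps = eps
        # Moving statistics start at mean 0.0 and variance 1.0.
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, batch, is_training, run_update_ops=False):
        if is_training:
            # Training normalizes with the statistics of the current batch.
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            if run_update_ops:
                # The "update ops": fold the batch statistics into the moving ones.
                d = self.decay
                self.moving_mean = d * self.moving_mean + (1 - d) * mean
                self.moving_var = d * self.moving_var + (1 - d) * var
        else:
            # Inference normalizes with the stored moving statistics.
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


def train(run_update_ops, steps=200, seed=0):
    """Feed random batches through the layer in training mode."""
    rng = random.Random(seed)
    bn = ToyBatchNorm()
    for _ in range(steps):
        batch = [rng.gauss(5.0, 2.0) for _ in range(64)]
        bn(batch, is_training=True, run_update_ops=run_update_ops)
    return bn


def eval_output_mean(bn, seed=1):
    """Mean of the layer output in inference mode on fresh data."""
    rng = random.Random(seed)
    batch = [rng.gauss(5.0, 2.0) for _ in range(256)]
    out = bn(batch, is_training=False)
    return sum(out) / len(out)


broken = train(run_update_ops=False)  # update step never executed
fixed = train(run_update_ops=True)    # update step runs with every training step

print("broken:", broken.moving_mean, broken.moving_var, eval_output_mean(broken))
print("fixed: ", fixed.moving_mean, fixed.moving_var, eval_output_mean(fixed))
```

In the `broken` run the moving statistics keep their initial values, so the inference output is not centered around zero even though every training batch was. In the `fixed` run the update step executes together with each training step, which is the role a dependency on the batch-norm update ops plays in the real graph.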

The solution is to add a dependency on the GraphKeys.UPDATE_OPS collection. The contrib module already defines a function that does this (optimize_loss()