这是在训练期间重用tf.slim图进行验证的正确方法吗?

时间:2018-04-01 13:48:17

标签: tensorflow

我正在尝试在训练期间的每个时期之后进行验证。

我按如下方式创建图表:

import tensorflow as tf
from networks import densenet
from networks.densenet_utils import dense_arg_scope

with tf.variable_scope('scope') as scope:
    with slim.arg_scope(dense_arg_scope()):
        logits_train, _ = densenet(images, blocks=networks[
            'densenet_265'], num_classes=1000, data_name='imagenet', is_training=True, scope='densenet265',
                                 reuse=tf.AUTO_REUSE)
    scope.reuse_variables()
    with slim.arg_scope(dense_arg_scope()):
        logits_val, _ = densenet(images, blocks=networks[
            'densenet_265'], num_classes=1000, data_name='imagenet', is_training=False, scope='densenet265',
                                 reuse=tf.AUTO_REUSE)

为了在培训或验证期间获得logits,我会执行以下操作:

is_training = tf.Variable(True, trainable=False, dtype=tf.bool)
training_mode = tf.assign(is_training, True)
validation_mode = tf.assign(is_training, False)
logits = tf.cond(tf.equal(is_training, tf.constant(True, dtype=tf.bool)), lambda: logits_train,
                     lambda: logits_val)

然而,当我运行我的代码时,我收到OOM错误。我确信这不是因为批量大。这是因为,之前我犯了一个大错,并且在训练和验证过程中使用了相同的图表。当时批量大小为32且图片大小为224x224x3,代码运行得非常好。

我怀疑在使用is_training=False进行验证时尝试重用图表时出现了一些错误。

密集网络的代码取自以下两个文件: densenet_utils.py densenet.py

1 个答案:

答案 0 :(得分:2)

您在logits_train和logits_val中创建了两个独立的网络,因此这会占用网络内存的两倍。 (我假设它正确设置并且变量正确共享,这可能是另一个问题,但这不会导致OOM,大数据是激活,而不是权重。)

没有必要这样做。使用相同的网络logits_train进行验证。事实证明参数is_training也可以采用布尔标量张量,因此您可以动态切换训练或推理模式。

所以,在您设置images占位符的位置,请将此行作为下一行:

training_mode = tf.placeholder( shape = None, dtype = tf.bool )

然后在上面的代码中,像这样设置您的网络:

logits_train, _ = densenet(images, blocks=networks['densenet_265'],
    num_classes=1000, data_name='imagenet', is_training=training_mode,
    scope='densenet265', reuse=tf.AUTO_REUSE)

请注意,is_training参数的值填充了上面的张量training_mode

然后当您执行sess.run( [ ... ] )命令(在上面的代码中不可见)时,您应该在training_mode中包含feed_dict,如此(伪代码):

result = sess.run( [ ??? ], feed_dict = { images : ???, training_mode : True / False } )

请注意,根据您是否正在接受培训,training_mode张量现在填充了False或True。

这是基于我对 batch_normalization dropout 图层的研究。