我正在尝试在训练期间的每个时期之后进行验证。
我按如下方式创建图表:
import tensorflow as tf
from networks import densenet
from networks.densenet_utils import dense_arg_scope
with tf.variable_scope('scope') as scope:
with slim.arg_scope(dense_arg_scope()):
logits_train, _ = densenet(images, blocks=networks[
'densenet_265'], num_classes=1000, data_name='imagenet', is_training=True, scope='densenet265',
reuse=tf.AUTO_REUSE)
scope.reuse_variables()
with slim.arg_scope(dense_arg_scope()):
logits_val, _ = densenet(images, blocks=networks[
'densenet_265'], num_classes=1000, data_name='imagenet', is_training=False, scope='densenet265',
reuse=tf.AUTO_REUSE)
为了在培训或验证期间获得logits
,我会执行以下操作:
is_training = tf.Variable(True, trainable=False, dtype=tf.bool)
training_mode = tf.assign(is_training, True)
validation_mode = tf.assign(is_training, False)
logits = tf.cond(tf.equal(is_training, tf.constant(True, dtype=tf.bool)), lambda: logits_train,
lambda: logits_val)
然而,当我运行我的代码时,我收到OOM错误。我确信这不是因为批量大。这是因为,之前我犯了一个大错,并且在训练和验证过程中使用了相同的图表。当时批量大小为32
且图片大小为224x224x3
,代码运行得非常好。
我怀疑在使用is_training=False
进行验证时尝试重用图表时出现了一些错误。
密集网络的代码取自以下两个文件: densenet_utils.py densenet.py
答案 0 :(得分:2)
您在logits_train和logits_val中创建了两个独立的网络,因此这会占用网络内存的两倍。 (我假设它正确设置并且变量正确共享,这可能是另一个问题,但这不会导致OOM,大数据是激活,而不是权重。)
没有必要这样做。使用相同的网络logits_train
进行验证。事实证明参数is_training
也可以采用布尔标量张量,因此您可以动态切换训练或推理模式。
所以,在您设置images
占位符的位置,请将此行作为下一行:
training_mode = tf.placeholder( shape = None, dtype = tf.bool )
然后在上面的代码中,像这样设置您的网络:
logits_train, _ = densenet(images, blocks=networks['densenet_265'],
num_classes=1000, data_name='imagenet', is_training=training_mode,
scope='densenet265', reuse=tf.AUTO_REUSE)
请注意,is_training
参数的值填充了上面的张量training_mode
!
然后当您执行sess.run( [ ... ] )
命令(在上面的代码中不可见)时,您应该在training_mode
中包含feed_dict
,如此(伪代码):
result = sess.run( [ ??? ], feed_dict = { images : ???, training_mode : True / False } )
请注意,根据您是否正在接受培训,training_mode
张量现在填充了False或True。
这是基于我对 batch_normalization 和 dropout 图层的研究。