具有is_training真和假的Tensorflow(tf-slim)模型

时间:2016-09-06 16:13:03

标签: tensorflow tf-slim

我想在列车集(is_training=True)和验证集(is_training=False)上运行给定模型,特别是如何应用dropout。现在,prebuilt models公开了一个参数is_training,在构建网络时将其传递给dropout图层。问题是,如果我使用不同的is_training值调用该方法两次,我会得到两个不共享权重的网络(我认为?)。如何让两个网络共享相同的权重,以便我可以运行我在验证集上训练过的网络?

3 个答案:

答案 0 :(得分:1)

我写了一个解决方案,你的评论是在火车和测试模式下使用Overfeat。 (我无法测试它,所以你可以检查它是否有效?)

首先是一些导入和参数:

import tensorflow as tf
slim = tf.contrib.slim
overfeat = tf.contrib.slim.nets.overfeat

batch_size = 32
inputs = tf.placeholder(tf.float32, [batch_size, 231, 231, 3])
dropout_keep_prob = 0.5
num_classes = 1000

在训练模式中,我们将正常范围传递给函数overfeat

scope = 'overfeat'
is_training = True

output = overfeat.overfeat(inputs, num_classes, is_training,         
                           dropout_keep_prob, scope=scope)

然后在测试模式下,我们使用reuse=True创建相同的范围。

scope = tf.VariableScope(reuse=True, name='overfeat')
is_training = False

output = overfeat.overfeat(inputs, num_classes, is_training,         
                           dropout_keep_prob, scope=scope)

答案 1 :(得分:0)

你可以使用占位符进行is_training:

isTraining = tf.placeholder(tf.bool)

# create nn
net = ...
net = slim.dropout(net,
                   keep_prob=0.5,
                   is_training=isTraining)
net = ...

# training
sess.run([net], feed_dict={isTraining: True})

# testing
sess.run([net], feed_dict={isTraining: False})

答案 2 :(得分:0)

这取决于具体情况,解决方案是不同的。

我的第一个选择是使用不同的流程来进行评估。您只需要检查是否有新的检查点并将权重加载到评估网络中(使用is_training=False):

checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
# wait until a new check point is available
while self.lastest_checkpoint == checkpoint:
    time.sleep(30)  # sleep 30 seconds waiting for a new checkpoint
    checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
logging.info('Restoring model from {}'.format(checkpoint))
self.saver.restore(session, checkpoint)
self.lastest_checkpoint = checkpoint

第二个选项是在每个纪元后卸载图表并创建新的评估图表。这种解决方案浪费了大量时间来加载和卸载图表。

第三种选择是分享权重。但是,使用队列或数据集为这些网络提供服务可能会导致问题,因此您必须非常小心。我只将它用于Siamese网络。

with tf.variable_scope('the_scope') as scope:
    your_model(is_training=True)
    scope.reuse_variables()
    your_model(is_training=False)