Why does tf.cond appear to run both branches?

Date: 2018-05-08 01:25:30

Tags: tensorflow

To make it easy to switch between training and validation within the same process, I decided to use tf.cond in my graph definition.

Consider the following class structure used to define the TF graph:

import tensorflow as tf
class OverFeatAccurateBase(object):
    def __init__(self, input, numclasses, trainmode):
        self._numclasses = numclasses
        self._trainmode = trainmode
        self._logits = self._buildmodel(input)

    @property
    def numclasses(self):
        return self._numclasses

    def setmode(self, val):
        self._trainmode = val

    @property
    def mode(self):
        return self._trainmode

    @property
    def logits(self):
        return self._logits

    def _buildmodel(self, input):
        out = tf.layers.conv2d(input, filters=96,
                               kernel_size=[7, 7],
                               strides=[2, 2],
                               padding='valid',
                               data_format='channels_last',
                               activation=tf.nn.relu,
                               kernel_initializer=tf.initializers.random_normal(
                                   stddev=0.01,
                                   seed=0),
                               bias_initializer=tf.initializers.constant(0),
                               kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                   scale=10 ** (-5)),
                               reuse=tf.AUTO_REUSE,
                               name='conv1')

        out = tf.layers.max_pooling2d(out, pool_size=[3, 3],
                                      strides=[3, 3],
                                      padding='valid',
                                      data_format='channels_last',
                                      name='pool1')

        out = tf.layers.conv2d(out, filters=256,
                               kernel_size=[7, 7],
                               strides=[1, 1],
                               padding='valid',
                               data_format='channels_last',
                               activation=tf.nn.relu,
                               kernel_initializer=tf.initializers.random_normal(
                                   stddev=0.01,
                                   seed=0),
                               bias_initializer=tf.initializers.constant(0),
                               kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                   scale=10 ** (-5)),
                               reuse=tf.AUTO_REUSE,
                               name='conv2')

        out = tf.layers.max_pooling2d(out, pool_size=[2, 2],
                                      strides=[2, 2],
                                      padding='valid',
                                      data_format='channels_last',
                                      name='pool2')

        out = tf.layers.conv2d(out, filters=512,
                               kernel_size=[3, 3],
                               strides=[1, 1],
                               padding='same',
                               data_format='channels_last',
                               activation=tf.nn.relu,
                               kernel_initializer=tf.initializers.random_normal(
                                   stddev=0.01,
                                   seed=0),
                               bias_initializer=tf.initializers.constant(0),
                               kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                   scale=10 ** (-5)),
                               reuse=tf.AUTO_REUSE,
                               name='conv3')

        out = tf.layers.conv2d(out, filters=512,
                               kernel_size=[3, 3],
                               strides=[1, 1],
                               padding='same',
                               data_format='channels_last',
                               activation=tf.nn.relu,
                               kernel_initializer=tf.initializers.random_normal(
                                   stddev=0.01,
                                   seed=0),
                               bias_initializer=tf.initializers.constant(0),
                               kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                   scale=10 ** (-5)),
                               reuse=tf.AUTO_REUSE,
                               name='conv4')

        out = tf.layers.conv2d(out, filters=1024,
                               kernel_size=[3, 3],
                               strides=[1, 1],
                               padding='same',
                               data_format='channels_last',
                               activation=tf.nn.relu,
                               kernel_initializer=tf.initializers.random_normal(
                                   stddev=0.01,
                                   seed=0),
                               bias_initializer=tf.initializers.constant(0),
                               kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                   scale=10 ** (-5)),
                               reuse=tf.AUTO_REUSE,
                               name='conv5')

        out = tf.layers.conv2d(out, filters=1024,
                               kernel_size=[3, 3],
                               strides=[1, 1],
                               padding='same',
                               data_format='channels_last',
                               activation=tf.nn.relu,
                               kernel_initializer=tf.initializers.random_normal(
                                   stddev=0.01,
                                   seed=0),
                               bias_initializer=tf.initializers.constant(0),
                               kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                   scale=10 ** (-5)),
                               reuse=tf.AUTO_REUSE,
                               name='conv6')

        out = tf.layers.max_pooling2d(out, pool_size=[3, 3],
                                      strides=[3, 3],
                                      padding='valid',
                                      data_format='channels_last',
                                      name='pool3')

        out = tf.layers.flatten(out, name='flatten')

        out = tf.layers.dense(out, units=4096, activation=tf.nn.relu,
                              kernel_initializer=tf.initializers.random_normal(
                                  stddev=0.01,
                                  seed=0),
                              bias_initializer=tf.initializers.constant(0),
                              kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                  scale=10 ** (-5)),
                              reuse=tf.AUTO_REUSE,
                              name='full1'
                              )

        out = tf.cond(tf.equal(self.mode, tf.constant(True)),
                      lambda: tf.layers.dropout(out, seed=0),
                      lambda: tf.Print(out, [out], 'The shape is'))

        out = tf.layers.dense(out, units=4096, activation=tf.nn.relu,
                              kernel_initializer=tf.initializers.random_normal(
                                  stddev=0.01,
                                  seed=0),
                              bias_initializer=tf.initializers.constant(0),
                              kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                  scale=10 ** (-5)),
                              reuse=tf.AUTO_REUSE,
                              name='full2'
                              )
        out = tf.cond(tf.equal(self.mode, tf.constant(True)),
                      lambda: tf.layers.dropout(
                          out, seed=0), lambda: out)
        logits = tf.layers.dense(out, units=self.numclasses,
                                 activation=tf.nn.relu,
                                 kernel_initializer=tf.initializers.random_normal(
                                     stddev=0.01,
                                     seed=0),
                                 bias_initializer=tf.initializers.constant(0),
                                 kernel_regularizer=tf.contrib.layers.l2_regularizer(
                                     scale=10 ** (-5)),
                                 reuse=tf.AUTO_REUSE,
                                 name='output'
                                 )

        return logits

Now let's test the structure above with the following snippet:

from networks.overfeataccuratebase import OverFeatAccurateBase
import tensorflow as tf
import numpy as np

inp = np.random.randn(10,221,221,3)


input = tf.placeholder(dtype=tf.float32, shape=(None, 221, 221, 3),
                       name='input')

mode_train = tf.constant(True)

mode_val = tf.constant(False)

net = OverFeatAccurateBase(input, 1000, mode_train)

logits = net.logits

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    writer = tf.summary.FileWriter('./tboard', graph=sess.graph)
    sess.run(init_op)
    print(sess.run(logits, feed_dict={input: inp}))
    net.setmode(mode_val)
    print(sess.run(net.mode))
    print(sess.run(logits, feed_dict={input: inp}))
    writer.close()

When the snippet above is run, you can see that even though net.setmode(mode_val) was called, the graph still runs in training mode: the tf.Print node in the tf.cond statement is never executed. What am I missing?

1 Answer:

Answer 0 (score: 1)

tf.cond works by running both branches of the conditional and then making sure that the correct branch is the one that assigns its value to the output. That is why you see the print statement appear when you do not expect it.
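One well-known manifestation of this (a minimal standalone sketch, assuming TF 1.x graph mode; the tensor names are made up) is that an op created outside the branch functions executes even when its branch is not taken:

import tensorflow as tf

x = tf.constant(3.0)
pred = tf.constant(False)

# The Print op is created *outside* the branch lambdas, so it is not placed
# inside the conditional context and runs whenever the cond output is evaluated.
printed = tf.Print(x, [x], 'side effect fired: ')

out = tf.cond(pred, lambda: printed, lambda: x + 1.0)

with tf.Session() as sess:
    print(sess.run(out))  # returns 4.0, yet the Print message still appears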

It looks like the only purpose of the tf.cond statements is to enable or disable dropout. The way I handle this in my own code is to make the keep probability a placeholder with a default value of 1.0. During training I feed in the appropriate keep probability, and during validation/testing I leave the default in place, which effectively disables dropout.
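A minimal sketch of that approach (assuming TF 1.x; `features` here is a stand-in for the output of the poster's full1 layer, and `train_op` is a hypothetical training op):

import tensorflow as tf

# keep_prob defaults to 1.0, which makes dropout an identity op.
keep_prob = tf.placeholder_with_default(1.0, shape=[], name='keep_prob')

features = tf.placeholder(tf.float32, shape=(None, 4096))
out = tf.nn.dropout(features, keep_prob=keep_prob)

# Training step: feed the real keep probability to enable dropout.
#   sess.run(train_op, feed_dict={features: batch, keep_prob: 0.5})
# Validation/testing: leave the default, and dropout is effectively disabled.
#   sess.run(out, feed_dict={features: batch})

This avoids tf.cond entirely, so there is a single code path through the graph and nothing to switch at graph-construction time.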