To switch easily between training and validation within the same process, I decided to use tf.cond in my graph definition.
Consider the following class structure used to build the TF graph:
import tensorflow as tf
class OverFeatAccurateBase(object):
    def __init__(self, input, numclasses, trainmode):
        self._numclasses = numclasses
        self._trainmode = trainmode
        self._logits = self._buildmodel(input)

    @property
    def numclasses(self):
        return self._numclasses

    def setmode(self, val):
        self._trainmode = val

    @property
    def mode(self):
        return self._trainmode

    @property
    def logits(self):
        return self._logits

    def _buildmodel(self, input):
        # Convolution/pooling feature extractor
        out = tf.layers.conv2d(
            input, filters=96, kernel_size=[7, 7], strides=[2, 2],
            padding='valid', data_format='channels_last', activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='conv1')
        out = tf.layers.max_pooling2d(
            out, pool_size=[3, 3], strides=[3, 3], padding='valid',
            data_format='channels_last', name='pool1')
        out = tf.layers.conv2d(
            out, filters=256, kernel_size=[7, 7], strides=[1, 1],
            padding='valid', data_format='channels_last', activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='conv2')
        out = tf.layers.max_pooling2d(
            out, pool_size=[2, 2], strides=[2, 2], padding='valid',
            data_format='channels_last', name='pool2')
        out = tf.layers.conv2d(
            out, filters=512, kernel_size=[3, 3], strides=[1, 1],
            padding='same', data_format='channels_last', activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='conv3')
        out = tf.layers.conv2d(
            out, filters=512, kernel_size=[3, 3], strides=[1, 1],
            padding='same', data_format='channels_last', activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='conv4')
        out = tf.layers.conv2d(
            out, filters=1024, kernel_size=[3, 3], strides=[1, 1],
            padding='same', data_format='channels_last', activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='conv5')
        out = tf.layers.conv2d(
            out, filters=1024, kernel_size=[3, 3], strides=[1, 1],
            padding='same', data_format='channels_last', activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='conv6')
        out = tf.layers.max_pooling2d(
            out, pool_size=[3, 3], strides=[3, 3], padding='valid',
            data_format='channels_last', name='pool3')

        # Fully connected classifier head
        out = tf.layers.flatten(out, name='flatten')
        out = tf.layers.dense(
            out, units=4096, activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='full1')
        # Dropout in training mode; otherwise print the tensor via tf.Print
        out = tf.cond(tf.equal(self.mode, tf.constant(True)),
                      lambda: tf.layers.dropout(out, seed=0),
                      lambda: tf.Print(out, [out], 'The shape is'))
        out = tf.layers.dense(
            out, units=4096, activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='full2')
        out = tf.cond(tf.equal(self.mode, tf.constant(True)),
                      lambda: tf.layers.dropout(out, seed=0),
                      lambda: out)
        logits = tf.layers.dense(
            out, units=self.numclasses, activation=tf.nn.relu,
            kernel_initializer=tf.initializers.random_normal(stddev=0.01, seed=0),
            bias_initializer=tf.initializers.constant(0),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=10 ** (-5)),
            reuse=tf.AUTO_REUSE, name='output')
        return logits
Now let's test the above structure with the following snippet:

from networks.overfeataccuratebase import OverFeatAccurateBase
import tensorflow as tf
import numpy as np
inp = np.random.randn(10,221,221,3)
input = tf.placeholder(dtype=tf.float32, shape=(None, 221, 221, 3),
                       name='input')
mode_train = tf.constant(True)
mode_val = tf.constant(False)
net = OverFeatAccurateBase(input, 1000, mode_train)
logits = net.logits
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    writer = tf.summary.FileWriter('./tboard', graph=sess.graph)
    sess.run(init_op)
    print(sess.run(logits, feed_dict={input: inp}))
    net.setmode(mode_val)
    print(sess.run(net.mode))
    print(sess.run(logits, feed_dict={input: inp}))
    writer.close()
When running the above snippet, one can see that despite calling net.setmode(mode_val), the graph still runs in training mode: the tf.Print node inside the tf.cond statement never executes. What am I missing?
Answer 0 (score: 1)
The way tf.cond works is that it runs both branches of the if statement and then makes sure the correct branch is the one whose value gets assigned to the output. That is why you see the print statement show up when you don't expect it.
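As a standalone illustration of that side-effect behaviour (a minimal sketch assuming TF 1.x graph mode; the tensor names below are made up and not taken from the question's model), an op created outside the two branch lambdas, such as a tf.Print, fires no matter which branch the predicate selects:

import tensorflow as tf

# `printed` is created OUTSIDE the tf.cond branch functions, so its print
# side effect runs regardless of the value of `pred`.
pred = tf.placeholder(tf.bool, shape=(), name='pred')
x = tf.constant(3.0)
printed = tf.Print(x, [x], 'this prints either way: ')
out = tf.cond(pred, lambda: printed * 2.0, lambda: x + 1.0)

with tf.Session() as sess:
    # Returns 4.0 from the false branch, yet tf.Print still logs to stderr.
    print(sess.run(out, feed_dict={pred: False}))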
It looks like the only purpose of the tf.cond statement is to enable or disable dropout. The way I do this in my own code is to make the dropout keep probability a placeholder with a default value of 1.0. Then during training I feed in the appropriate keep probability, and during validation/testing I leave the default in place, which effectively disables dropout.
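A minimal sketch of that placeholder approach, assuming TF 1.x and tf.nn.dropout (the layer names and sizes below are illustrative, not taken from the question's model):

import numpy as np
import tensorflow as tf

# keep_prob defaults to 1.0, i.e. dropout is disabled unless a value is fed.
keep_prob = tf.placeholder_with_default(1.0, shape=(), name='keep_prob')

x = tf.placeholder(tf.float32, shape=(None, 4096), name='features')
hidden = tf.layers.dense(x, units=4096, activation=tf.nn.relu, name='full1')
hidden = tf.nn.dropout(hidden, keep_prob=keep_prob)  # no tf.cond needed

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(2, 4096).astype(np.float32)
    # Training step: feed the keep probability you actually want.
    sess.run(hidden, feed_dict={x: batch, keep_prob: 0.5})
    # Validation/test step: omit keep_prob and the default of 1.0 is used.
    sess.run(hidden, feed_dict={x: batch})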