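I am using batch normalization with TensorFlow. The problem is that the network seems to require the tf.cond operator inside the batchnorm layer; otherwise it reports the following error:

    ValueError("None values not supported.")

I am confused by this. Here is a simple example:

    conv_bias = some_layer_defined_previously
    train = tf.cast((True), tf.bool)
    conv_bias = self.batch_norm2(conv_bias, filters, train)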
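This network runs fine. However, if I change it to the following:

    conv_bias = some_layer_defined_previously
    train = True
    conv_bias = self.batch_norm3(conv_bias, filters, train)

then the error occurs.
I am confused because the only difference between the two networks is the batchnorm layer. The two batchnorm layers are defined as follows: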
batchnorm2
    def batch_norm2(self, x, size, training, decay=0.999):
        beta = tf.Variable(tf.zeros([size]), name='beta')
        scale = tf.Variable(tf.ones([size]), name='scale')
        # Note: the second positional argument of tf.Variable is `trainable`,
        # so 'mean' and 'var' are not used as variable names here.
        pop_mean = tf.Variable(tf.zeros([size]), 'mean')
        pop_var = tf.Variable(tf.ones([size]), 'var')
        epsilon = 1e-3
        self.model_params += [beta, scale, pop_mean, pop_var]
        # Per-channel statistics over the batch, height, and width axes.
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
        # Moving-average updates of the population statistics.
        train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
        def batch_statistics():
            with tf.control_dependencies([train_mean, train_var]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
        def population_statistics():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')
        return tf.cond(training, batch_statistics, population_statistics)
batchnorm3
    def batch_norm3(self, x, size, training, decay=0.999):
        beta = tf.Variable(tf.zeros([size]), name='beta')
        scale = tf.Variable(tf.ones([size]), name='scale')
        # Note: the second positional argument of tf.Variable is `trainable`,
        # so 'mean' and 'var' are not used as variable names here.
        pop_mean = tf.Variable(tf.zeros([size]), 'mean')
        pop_var = tf.Variable(tf.ones([size]), 'var')
        epsilon = 1e-3
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
        train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
        # Always normalize with batch statistics; `training` is unused.
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
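The only difference between the two definitions is that in batchnorm3 I removed the tf.cond and directly used the True branch of the tf.cond from batchnorm2. In batchnorm3 the train variable is not actually used, so whether I set train = True or train = tf.cast((True), tf.bool), the error is exactly the same.
The error message is the one quoted above:

    ValueError("None values not supported.")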
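To help narrow this down, here is a minimal, self-contained sketch (not my actual network) that rebuilds a batch_norm3-style graph and reports which variables have a None gradient. It rests on an assumption: that the ValueError comes from code that processes gradients, which is only a guess because that code is not shown here. The helper name find_none_gradients, the input shape, and the dummy loss are all hypothetical. It uses tf.compat.v1 so that the TF 1.x-style calls also run on TensorFlow 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, also available in TensorFlow 2.x


def find_none_gradients():
    """Hypothetical helper: build a batch_norm3-style graph and report,
    per variable, whether its gradient w.r.t. a dummy loss is None."""
    size, decay, epsilon = 3, 0.999, 1e-3
    with tf.Graph().as_default():
        # Assumed NHWC input; the shape is arbitrary and only for illustration.
        x = tf1.placeholder(tf.float32, [None, 4, 4, size])
        beta = tf1.Variable(tf.zeros([size]), name='beta')
        scale = tf1.Variable(tf.ones([size]), name='scale')
        pop_mean = tf1.Variable(tf.zeros([size]), name='mean')
        pop_var = tf1.Variable(tf.ones([size]), name='var')

        batch_mean, batch_var = tf1.nn.moments(x, [0, 1, 2])
        train_mean = tf1.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf1.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

        # Same structure as batch_norm3: batch statistics only, no tf.cond.
        with tf1.control_dependencies([train_mean, train_var]):
            y = tf1.nn.batch_normalization(
                x, batch_mean, batch_var, beta, scale, epsilon)

        loss = tf1.reduce_sum(y)  # dummy loss, an assumption for this sketch
        labels = ['beta', 'scale', 'pop_mean', 'pop_var']
        grads = tf1.gradients(loss, [beta, scale, pop_mean, pop_var])
        # A None entry means the loss has no data path to that variable.
        return {label: grad is None for label, grad in zip(labels, grads)}


if __name__ == '__main__':
    for label, is_none in find_none_gradients().items():
        print(label, 'gradient is None:', is_none)
```

If pop_mean and pop_var are reported as having None gradients, then any later op that is handed those gradients directly would be one candidate source of the error, but I have not confirmed that this is what happens in my network.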
What is the reason for this? Thanks for your help!