tensorflow: getting ValueError("None values not supported.") with batch normalization

Date: 2019-04-02 02:55:19

Tags: python tensorflow

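I am using batch normalization in TensorFlow. The problem is that the network apparently has to use a tf.cond operator inside the batchnorm layer; otherwise it reports the following error: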

ValueError("None values not supported.")

This confuses me. Here is a simple example:

conv_bias = some_layer_defined_previously
train = tf.cast((True), tf.bool)
conv_bias = self.batch_norm2(conv_bias, filters, train)

This network runs fine. However, if I change it to the following:

conv_bias = some_layer_defined_previously
train = True
conv_bias = self.batch_norm3(conv_bias, filters, train)

then the error occurs.

I'm confused, because the only difference between the two networks is the batchnorm layer. The two batchnorm layers are defined as follows:

batchnorm2

def batch_norm2(self, x, size, training, decay=0.999):
    beta = tf.Variable(tf.zeros([size]), name='beta')
    scale = tf.Variable(tf.ones([size]), name='scale')
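    # name= must be passed by keyword: the 2nd positional argument of tf.Variable is trainable
    pop_mean = tf.Variable(tf.zeros([size]), name='mean')
    pop_var = tf.Variable(tf.ones([size]), name='var')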
    epsilon = 1e-3
    self.model_params += [beta, scale, pop_mean, pop_var]
    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
    train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
    train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

    def batch_statistics():
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')

    def population_statistics():
        return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')

    return tf.cond(training, batch_statistics, population_statistics)
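
For reference, the two branches of batch_norm2 can be modelled in plain Python for a single channel. This is a minimal sketch under stated assumptions: batch_norm_reference is a hypothetical helper, not part of TensorFlow; it models what each branch is meant to compute and does not model TensorFlow graph execution (for example, exactly when the tf.assign ops run).

```python
import math


def batch_norm_reference(values, pop_mean, pop_var, beta, scale,
                         training, decay=0.999, epsilon=1e-3):
    """Hypothetical single-channel, pure-Python model of batch_norm2's two branches.

    `values` holds one channel's activations flattened over the axes that
    tf.nn.moments(x, [0, 1, 2]) reduces (batch, height, width).
    Returns (normalized_values, new_pop_mean, new_pop_var).
    """
    n = len(values)
    batch_mean = sum(values) / n
    # Biased (divide-by-n) variance, the same quantity tf.nn.moments returns.
    batch_var = sum((v - batch_mean) ** 2 for v in values) / n

    if training:
        # batch_statistics(): update the moving averages, normalize with batch stats.
        new_pop_mean = pop_mean * decay + batch_mean * (1 - decay)
        new_pop_var = pop_var * decay + batch_var * (1 - decay)
        mean, var = batch_mean, batch_var
    else:
        # population_statistics(): keep the moving averages, normalize with them.
        new_pop_mean, new_pop_var = pop_mean, pop_var
        mean, var = pop_mean, pop_var

    # Same formula as tf.nn.batch_normalization: scale * (x - mean) / sqrt(var + eps) + offset.
    inv_std = 1.0 / math.sqrt(var + epsilon)
    normalized = [scale * (v - mean) * inv_std + beta for v in values]
    return normalized, new_pop_mean, new_pop_var


# Training step: output is centred on the batch mean (2.5); moving averages move slightly.
out, m, v = batch_norm_reference([1.0, 2.0, 3.0, 4.0], 0.0, 1.0, 0.0, 1.0, training=True)
# Inference step: stored statistics are used and left unchanged.
out_eval, m_eval, v_eval = batch_norm_reference([1.0, 2.0, 3.0, 4.0], 0.0, 1.0, 0.0, 1.0, training=False)
```

With decay=0.999, a single training step nudges the population mean from 0.0 to 0.0025, which is why the population statistics only become useful after many batches.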

batchnorm3

def batch_norm3(self, x, size, training, decay=0.999):
    beta = tf.Variable(tf.zeros([size]), name='beta')
    scale = tf.Variable(tf.ones([size]), name='scale')
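    # name= must be passed by keyword: the 2nd positional argument of tf.Variable is trainable
    pop_mean = tf.Variable(tf.zeros([size]), name='mean')
    pop_var = tf.Variable(tf.ones([size]), name='var')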
    epsilon = 1e-3
    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
    train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
    train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
    with tf.control_dependencies([train_mean, train_var]):
        return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
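
In batch_norm3 the population variables only receive moving-average updates; the returned value depends on batch_mean and batch_var alone, and the training argument is never read. The sketch below is a hypothetical pure-Python model of that data flow for one channel (batch_norm3_reference is an illustrative name, no TensorFlow involved), assuming the same formulas as batch_norm3.

```python
import math


def batch_norm3_reference(values, pop_mean, pop_var, beta, scale,
                          training, decay=0.999, epsilon=1e-3):
    """Hypothetical single-channel, pure-Python model of batch_norm3's data flow.

    Returns (normalized_values, new_pop_mean, new_pop_var).
    """
    del training  # never read, as in batch_norm3
    n = len(values)
    batch_mean = sum(values) / n
    batch_var = sum((v - batch_mean) ** 2 for v in values) / n
    # Side effect only: the tf.assign moving-average updates.
    new_pop_mean = pop_mean * decay + batch_mean * (1 - decay)
    new_pop_var = pop_var * decay + batch_var * (1 - decay)
    # The returned values depend on batch statistics alone.
    inv_std = 1.0 / math.sqrt(batch_var + epsilon)
    normalized = [scale * (v - batch_mean) * inv_std + beta for v in values]
    return normalized, new_pop_mean, new_pop_var


x = [1.0, 2.0, 3.0, 4.0]
out_a, m_a, _ = batch_norm3_reference(x, 0.0, 1.0, 0.0, 1.0, training=True)
out_b, _, _ = batch_norm3_reference(x, 5.0, 3.0, 0.0, 1.0, training=False)
# Different stored statistics and a different flag give the same output.
print(out_a == out_b)  # True
```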

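The only difference between these two definitions is that I removed the tf.cond from batchnorm2 and, in batchnorm3, used the True branch of the tf.cond directly. Since batchnorm3 never actually uses the training argument, the error is exactly the same whether I set train = tf.cast((True), tf.bool) or train = True.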

The error raised is the one quoted above.

What is causing this? Thanks, everyone, for your help!

0 answers