张量融合批处理规范与批处理规范

时间:2020-04-09 17:57:24

标签: python tensorflow mean variance batchnorm

在推理期间,我想计算tf.nn.fused_batch_norm中输入的总体均值和方差,并将这些总体均值和方差与移动平均均值和方差进行比较。我以为我可以潜在地使用tf.nn_moments来找到总体均值/方差,但是现在我不确定,因为基于在线阅读帖子,总体均值/方差和矩函数返回的均值/方差不是相同。任何帮助,将不胜感激。这是下面的一些代码:

这些行提取推理期间存储在网络中的移动平均值和方差:

variables = tf.global_variables()
vars_moving_mean = []
vars_moving_variance = []
for var in variables:
    if ("bn/mean" in var.name):
         vars_moving_mean.append(var)

    elif ("bn/variance" in var.name):
         vars_moving_variance.append(var)

我现在要做的是,将数据输入到批处理规范层,并进行自己的计算以找到均值和方差(就像正在训练模型一样),因此我可以将这些值与移动均值进行比较和上面提取的方差。我希望使用以下代码提取这些值:

mean, var = tf.nn.moments(input_to_batch_norm, axes =[0,1,2])

但是如果网络使用的是融合的批处理规范,而不是批处理规范,我仍然可以使用它吗?

0 个答案:

没有答案