哪个维度的批量归一化?

时间:2021-05-29 18:04:45

标签: pytorch batch-normalization

我们在哪个维度上计算均值和标准差?是在神经网络层的隐藏维度上,还是在批次中的所有样本上分别针对每个隐藏维度?

论文中说我们对批次进行标准化。

torch.nn.BatchNorm1d 中,输入参数是 num_features,这对我来说没有意义。

为什么 pytorch 不遵循 Batchnormalization 的原始论文?

1 个答案:

答案 0 :(得分:0)

<块引用>

我们在哪个维度上计算均值和标准差?

在第 0 个维度上,对于形状 1D(batch, num_features) 输入,它将是:

batch = 64
features = 12
data = torch.randn(batch, features)

mean = torch.mean(data, dim=0)
var = torch.var(data, dim=0)
<块引用>

在 torch.nn.BatchNorm1d 中,输入参数是“num_features”, 这对我来说毫无意义。

它与归一化无关,而是通过meanvar可学习参数对gammabeta进行重新参数化。来自论文:

batchnorm

缩放和移位阶段使用的参数都是 num_features 形状,因此我们必须传递这个值才能用特定的形状初始化它们。

以下是一个从头开始的实现示例,供参考:

class BatchNorm1d(torch.nn.Module):
    def __init__(self, num_features, momentum: float = 0.9, eps: float = 1e-7):
        super().__init__()
        self.num_features = num_features

        self.gamma = torch.nn.Parameter(torch.ones(1, self.num_features))
        self.beta = torch.nn.Parameter(torch.zeros(1, self.num_features))
        
        self.register_buffer("running_mean", torch.ones(1, self.num_features))
        self.register_buffer("running_var", torch.ones(1, self.num_features))

        self.momentum = momentum
        self.eps = eps

    def forward(self, X):
        if not self.training:
            X_hat = X - self.running_mean / torch.sqrt(self.running_var + self.eps)
        else:
            mean = X.mean(dim=0).unsqueeze(dim=0)
            var = ((X - mean) ** 2).mean(dim=0).unsqueeze(dim=0)

            # Update running mean and variance
            self.running_mean *= self.momentum
            self.running_mean += (1 - self.momentum) * mean

            self.running_var *= self.momentum
            self.running_var += (1 - self.momentum) * var

            X_hat = X - mean / torch.sqrt(var + self.eps)

        return X_hat * self.gamma + self.beta
<块引用>

为什么 pytorch 不遵循 Batchnormalization 的原始论文?

一目了然

相关问题