在pytorch中使用BatchNorm进行培训

时间:2019-09-10 06:26:04

标签: deep-learning pytorch

我想知道在pytorch中使用BatchNorm进行训练时是否需要做任何特别的事情。根据我的理解,gammabeta参数是通过梯度进行更新的,这通常是由优化程序完成的。但是,批次的均值和方差会使用动量缓慢更新。

  1. 那么,当均值和方差参数更新时,我们是否需要指定优化器,还是pytorch会自动进行处理?
  2. 有没有一种方法可以访问BN层的均值和方差,因此我可以确保在训练模型时该值正在变化。

如果需要,这里是我的模型和训练过程:

def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0.):
    "Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`."
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0: layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))

    return nn.Sequential(*layers)

class Model(nn.Module):
    def __init__(self, i, o, h=()):
        super().__init__()

        nodes = (i,) + h + (o,)
        self.layers = nn.ModuleList([bn_drop_lin(i,o, p=0.5) 
                                     for i, o in zip(nodes[:-1], nodes[1:])])

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))

        return self.layers[-1](x)

培训:

for i, data in enumerate(trainloader):
    # get the inputs; data is a list of [inputs, labels]
    inputs, labels = data

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

1 个答案:

答案 0 :(得分:3)

Batchnorm层的行为会有所不同,具体取决于模型是处于火车模式还是评估模式。

net处于训练模式时(即调用net.train()之后),net中包含的批处理规范层将使用批处理统计信息以及gamma和beta参数来调整均值和方差每个小批量。在火车模式下,运行平均值和方差也将被调整。运行均值和方差的这些更新发生在正向传递过程中(调用net(inputs)时)。仅在调用optimizer.step()后才更新gamma和beta参数。

net处于评估模式(net.eval()时,批次规范将使用训练期间存储的历史运行平均值和运行方差来调整样本的平均值和运行方差。

您可以通过显示层running_meanrunning_var成员来确保批处理规范按预期进行更新,从而检查运行均值和方差的批处理规范层。可以通过分别显示批处理规范层的weightbias成员来访问可学习的gamma和beta参数。

修改

下面是一个简单的演示代码,显示running_mean在转发期间已更新。请注意,优化器未更新它。

>>> import torch
>>> import torch.nn as nn
>>> layer = nn.BatchNorm1d(5)
>>> layer.train()
>>> layer.running_mean
tensor([0., 0., 0., 0., 0.])
>>> result = layer(torch.randn(5,5))
>>> layer.running_mean
tensor([ 0.0271,  0.0152, -0.0403, -0.0703, -0.0056])