在设计网络时应在哪里拼接标准化?例如。如果您有堆叠的Transformer或Attention网络,那么在具有密集层之后的任何时候进行标准化都有意义吗?
答案 0 :(得分:0)
original paper试图解释的是减少过度拟合的使用批量标准化。
在设计网络时应该在哪里拼接标准化?
尽早设置输入的归一化。输入的极值不平衡会导致不稳定。
尽管对输出进行归一化也不会阻止输入再次引起不稳定。
以下是解释BN用途的小代码:
import torch
import torch.nn as nn
m = nn.BatchNorm1d(100, affine=False)
input = 1000*torch.randn(3, 100)
print(input)
output = m(input)
print(output)
print(output.mean()) # should be ~ 0
print(output.std()) # should be ~ 1
在具有密集层之后的任何时间进行归一化有意义吗
是的,您可以这样做,因为矩阵乘法可能导致产生极值。同样,在卷积层之后,由于它们也是矩阵乘法,因此与密集(nn.Linear
)层相比,相似但强度较低。例如,如果您打印重新发送的模型,您将看到每次在conv层之后设置批次规范,如下所示:
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
要打印完整的resnet,可以使用以下方法:
import torchvision.models as models
r = models.resnet18()
print(r)