与Pytorch中的BatchNorm相比,GroupNorm相当慢,并且消耗的GPU内存更高

时间:2019-09-19 01:25:08

标签: pytorch

我在pytorch中使用GroupNorm而不是BatchNorm,并保持所有其他(网络体系结构)不变。它表明在Imagenet数据集中,使用resnet50架构,GroupNorm比BatchNorm慢40%,并且比BatchNorm多消耗33%的GPU内存。我真的很困惑,因为GroupNorm不需要比BatchNorm更多的计算。详细信息如下。

有关“组归一化”的详细信息,请参见本文:https://arxiv.org/pdf/1803.08494.pdf

对于BatchNorm,一个小批量消耗12.8秒,GPU内存为7.51GB;

对于GroupNorm,一个小批量消耗17.9秒,GPU内存为10.02GB。

我使用以下代码将所有BatchNorm层转换为GroupNorm层。

def convert_bn_model_to_gn(module, num_groups=16):
"""
Recursively traverse module and its children to replace all instances of
``torch.nn.modules.batchnorm._BatchNorm`` with :class:`torch.nn.GroupNorm`.
Args:
    module: your network module
    num_groups: num_groups of GN
"""
mod = module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
    mod = nn.GroupNorm(num_groups, module.num_features,
                       eps=module.eps, affine=module.affine)
    # mod = nn.modules.linear.Identity()
    if module.affine:
        mod.weight.data = module.weight.data.clone().detach()
        mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
    mod.add_module(name, convert_bn_model_to_gn(
        child, num_groups=num_groups))
del module
return mod

1 个答案:

答案 0 :(得分:1)

是的,没错,与BN相比,GN确实使用了更多资源。我猜这是因为它必须计算每组通道的均值和方差,而BN只需在整个批次中计算一次。

但是GN的优点是,您可以将批处理大小降低到2,而不会降低本文所述的任何性能,因此您可以弥补开销计算。