一起使用nn.Linear()和nn.BatchNorm1d()

时间:2019-07-19 14:52:45

标签: pytorch

当数据为3D(批大小,高,宽)时,我不理解BatchNorm1d的工作原理。

示例

  • 输入大小:(2,50,70)
  • 层:nn.Linear(70,20)
  • 输出大小:(2,50,20)

如果我随后包含一个批处理规范化层,则需要num_features = 50:

  • BN:nn.BatchNorm1d(50)

我不明白为什么它不是20:

  • BN:nn.BatchNorm1d(20)

示例1)

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.bn11 = nn.BatchNorm1d(50)
        self.fc11 = nn.Linear(70,20)

    def forward(self, inputs):
        out = self.fc11(inputs)
        out = torch.relu(self.bn11(out))
        return out

model = Net()
inputs = torch.Tensor(2,50,70)
outputs = model(inputs)

示例2)

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.bn11 = nn.BatchNorm1d(20)
        self.fc11 = nn.Linear(70,20)

    def forward(self, inputs):
        out = self.fc11(inputs)
        out = torch.relu(self.bn11(out))
        return out

model = Net()
inputs = torch.Tensor(2,50,70)
outputs = model(inputs)
  • 示例1起作用。
  • 示例2引发错误:
    • RuntimeError:running_mean应该包含50个元素,而不是20个

2D示例:

  • 输入大小:(2,70)
  • 层:nn.Linear(70,20)
  • BN:nn.BatchNorm1d(20)

我认为BN层中的20个是由于线性层输出了20个节点,每个节点都需要一个运行平均值/标准输入值。

为什么在3D情况下,如果线性层有20个输出节点,那么BN层就没有20个要素?

1 个答案:

答案 0 :(得分:1)

一个人可以在torch.nn.Linear documentation内找到答案。

它取input形状为(N, *, I)并返回(N, *, O),其中I代表输入尺寸,O代表输出尺寸,*之间的任何尺寸。

如果将torch.Tensor(2,50,70)传递到nn.Linear(70,20),将得到形状为(2, 50, 20)的输出,当您使用BatchNorm1d时,它将计算第一个非批量尺寸的移动平均值,因此就是50。这就是您出错的原因。