当数据为3D(批大小,高,宽)时,我不理解BatchNorm1d的工作原理。
示例
如果我随后包含一个批处理规范化层,则需要num_features = 50:
我不明白为什么它不是20:
示例1)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bn11 = nn.BatchNorm1d(50)
self.fc11 = nn.Linear(70,20)
def forward(self, inputs):
out = self.fc11(inputs)
out = torch.relu(self.bn11(out))
return out
model = Net()
inputs = torch.Tensor(2,50,70)
outputs = model(inputs)
示例2)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bn11 = nn.BatchNorm1d(20)
self.fc11 = nn.Linear(70,20)
def forward(self, inputs):
out = self.fc11(inputs)
out = torch.relu(self.bn11(out))
return out
model = Net()
inputs = torch.Tensor(2,50,70)
outputs = model(inputs)
2D示例:
我认为BN层中的20个是由于线性层输出了20个节点,每个节点都需要一个运行平均值/标准输入值。
为什么在3D情况下,如果线性层有20个输出节点,那么BN层就没有20个要素?
答案 0 :(得分:1)
一个人可以在torch.nn.Linear
documentation内找到答案。
它取input
形状为(N, *, I)
并返回(N, *, O)
,其中I
代表输入尺寸,O
代表输出尺寸,*
之间的任何尺寸。
如果将torch.Tensor(2,50,70)
传递到nn.Linear(70,20)
,将得到形状为(2, 50, 20)
的输出,当您使用BatchNorm1d
时,它将计算第一个非批量尺寸的移动平均值,因此就是50。这就是您出错的原因。