nn.Linear层在火炬上附加尺寸的应用

时间:2019-01-30 16:00:13

标签: pytorch tensor

pytorch中的全连接层(nn.Linear)如何应用于“其他尺寸”? documentation说,它可以用于将张量(N,*,in_features)连接到(N,*,out_features),其中N在一个示例中的数量是批量的,因此它是无关紧要的,并且*是那些“附加”尺寸。这是否意味着要使用附加维度中的所有可能切片来训练单个图层,还是针对每个切片或其他内容对单独的图层进行训练?

1 个答案:

答案 0 :(得分:2)

in_features * out_features中学习了linear.weight个参数,在out_features中学习了linear.bias个参数。您可以将nn.Linear视为

  1. 将张量重整为某些(N', in_features),其中N'N*描述的所有尺寸的乘积:input_2d = input.reshape(-1, in_features)
  2. 应用标准矩阵-矩阵乘法output_2d = linear.weight @ input_2d
  3. 添加偏差output_2d += linear.bias.reshape(1, in_features)(请注意,我们在所有N'维度上都进行了广播)
  4. 将输出调整为与input相同的尺寸,除了最后一个:output = output_2d.reshape(*input.shape[:-1], out_features)
  5. return output

因此,前导尺寸N*尺寸相同。该文档使N显式地告诉您输入必须至少为 2d,但是可以根据需要输入任意尺寸。