Pytorch-tutorial:类定义中的奇怪输入参数

时间:2017-11-07 07:03:38

标签: python pytorch

我正在阅读一些pytorch教程。以下是残差块的定义。但是在forward方法中,每个函数句柄只接受一个参数out,而在__init__函数中,这些函数具有不同数量的输入参数:

# Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

有谁知道这是如何工作的? 它是标准的python类继承功能还是特定于pytorch?

2 个答案:

答案 0 :(得分:1)

在init函数中定义图层,这意味着参数。在转发功能中,您只需使用init中的预定义设置输入需要处理的数据。 nn.whatever使用您传递给它的设置构建函数。然后这个函数可以在forward中使用,这个函数只接受一个参数。

答案 1 :(得分:0)

您可以在类(__init__函数)的构造函数中定义网络体系结构的不同层。实质上,当您创建不同图层的实例时,可以使用设置参数对其进行初始化。

例如,当您声明第一个卷积层self.conv1时,您将提供初始化图层所需的参数。在前向功能中,您只需使用输入调用图层即可获得相应的输出。例如,在out = self.conv2(out)中,您可以获取上一个图层的输出,并将其作为输入提供给下一个self.conv2图层。

请注意,在初始化期间,您会向图层提供有关向该图层提供何种/形状输入的信息。例如,您告诉第一个卷积层输入中输入和输出通道的数量。在forward方法中,你只需要传递输入,就是它。