为什么大型神经网络比小型神经网络传播更快

时间:2018-07-01 14:17:27

标签: performance deep-learning conv-neural-network pytorch

我在pytorch中编写了以下两个NN进行图像分割: 较小的一个:

class ConvNetV0(nn.Module):

def __init__(self):
    super(ConvNetV0, self).__init__()
    self.conv1 = nn.Conv2d(3, 30, 4, padding=2)
    self.conv2 = nn.Conv2d(30, 50, 16, padding=7, bias=True)
    self.conv3 = nn.Conv2d(50, 20, 2, stride=2)
    self.conv4 = nn.Conv2d(20, 2, 2, stride=2)

def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    x = self.conv3(x)
    x = F.relu(x)
    y = self.conv4(x)
    return y

较大的一个:

class ConvNetV1(nn.Module):

def __init__(self):
    super(ConvNetV1, self).__init__()
    self.conv0 = nn.Conv2d(3, 50, 4, padding=1, stride=2)
    self.conv_r1 = nn.Conv2d(50, 40, 15, padding=7, bias=True)
    self.conv_r2 = nn.Conv2d(40, 25, 3, padding=1)
    self.conv_r3 = nn.Conv2d(25, 25, 2, stride=2)
    # self.conv_r3 = nn.MaxPool2d(2, stride=2)
    self.conv_b1 = nn.Conv2d(50, 15, 4, padding=1, stride=2)
    self.conv1 = nn.Conv2d(40, 2, 1)

def forward(self, x):
    x = self.conv0(x)
    x = F.relu(x)

    x1 = self.conv_r1(x)
    x1 = F.relu(x1)
    x1 = self.conv_r2(x1)
    x1 = F.relu(x1)
    x1 = self.conv_r3(x1)

    x2 = self.conv_b1(x)

    y = torch.cat([x1, x2], dim=1)
    y = self.conv1(y)
    return y

但是在小批量= 8的训练过程中,较小的网络需要2秒才能完成一次迭代,而较大的网络仅需要0.3秒即可完成一次迭代。

我还观察到两个网之间的参数之比为5:6。但是,在训练期间,较小的网络仅占用1GB VRAM,而较大的网络则占用3GB。由于我的1050ti具有4GB VRAM。我想以内存换取速度。知道我该怎么做吗?

1 个答案:

答案 0 :(得分:0)

我根据您指定大小的综合数据对您的模型进行了快速基准测试。至少在我的系统上,差异实际上不是由模型向前或向后给出的,而是由损耗的计算得出的。这可能是由于第一个模型使用了更多的GPU,因此是the queuing of the operations is slightly longer

事实上,您的第一个模型执行了约6.3亿次操作,而第二个模型则执行了约2.7亿次操作。请注意,在第二个模型中,您立即将要素映射的大小从256x256减小到128x128,而在第一个模型中,仅在最后两个卷积中减小了尺寸。这对执行的操作数量有很大影响。

因此,如果您希望使用类似V0的模型并使之更快,则应尝试立即减小要素图的大小。使用这种较小的模型(就内存而言),您还可以增加批处理的大小。

如果您想改用V1,则无能为力。您可以尝试使用Pytorch 0.4中引入的checkpoint来将内存的计算量交换到可以将批处理增加到16的程度。它可能会运行得更快,也可能不会更快,这取决于您需要多少计算量。权衡。

如果输入大小不变,您可以执行另一种简单的操作来使其运行更快,将其设置为torch.cudnn.benchmark = True。这将为特定配置寻找最快的算法集。