Torch.cat内存爆炸

时间:2019-02-12 08:02:30

标签: python pytorch

我试图用ResNet50上的此模块替换Conv2d。

class SubtractedConv(nn.Module):
    def __init__(self, input_ch, output_ch, kernels, stride=1):
        super().__init__()
        self.point_wise = nn.Conv2d(input_ch, output_ch//2, 1, bias=False, stride=stride)
        self.depth_wise = nn.Conv2d(output_ch // 2, output_ch // 2, kernels, groups=output_ch // 2, bias=False, padding=kernels // 2)
        self.low_pass = nn.Conv2d(output_ch // 2, output_ch // 2, kernels, bias=False, padding=kernels // 2)
    def forward(self, x):
        p = self.point_wise(x)
        d = self.depth_wise(p)
        d -= p 
        l = self.low_pass(p)
        x = torch.cat((d, l), 1)
        return x

预期的输出应该具有与普通Conv2d相同的通道,但是在torch.cat()中我的cuda内存不足。 我想知道为什么?以及如何处理呢?

我使用Pytorch 0.4.0并在Tesla P100上运行,图像尺寸为224 * 224,批处理尺寸为16。

实际上,类似的方法在Keras上有效,并且参数较少(ResNet50中为16M,而普通Conv2D为25M)。

def subtractedconv(input_tensor, kernel_size, filters, stride=1):
    p = kl.Conv2D(filters//2, (1, 1), use_bias=False, strides=stride, padding='same')(input_tensor)
    d = DepthwiseConv2D(kernel_size, use_bias=False, padding='same')(p)
    d = kl.subtract([d, p])
    l = kl.Conv2D(filters//2, kernel_size, use_bias=False, padding='same')(p)
    x = kl.Concatenate(axis=-1)([d, l])
    return x

1 个答案:

答案 0 :(得分:0)

PyTorch的问题很可能是中间张量的创建,而不是torch.cat本身。为了通过nn.Conv2d向后传播,您需要将该操作的输入保留在内存中。遍历各个层时,内存消耗增加(保留所有中间结果)。现在,在您的forward代码中,您拥有其中三个

p = self.point_wise(x) # x is kept
d = self.depth_wise(p) # p is kept
d -= p # here we do not need to keep d, because of derivative formula for subtraction
l = self.low_pass(p)
x = torch.cat((d, l), 1) # but since this goes into further processing, we will need to keep d anyway

请注意,即使您的操作在计算上可能是高效的(例如,内核较小),它们仍然需要相同数量的内存来保存输入要素图-换句话说,您为每个{ {1}},而不管其自身的复杂性如何。因此,很明显,如果将一个nn.Conv2d替换为三个nn.Conv2d,则可以预期内存消耗将增加大约三倍。

但是,您的情况有一个解决方法。由于您的整个操作是线性的(您只执行卷积,它是线性的,减法是线性的,而级联在某种意义上是线性的),因此您可以使用精心准备的内核将整个计算简化为一个卷积。如果我们将卷积视为线性算子,并用point_wise表示P运算,用depth_wise表示D,用low_pass表示L您的前向计算为concatenate(Dx - Px, LPx),可以简化为[concatenate(D-P, LP)]x。因此,您可以基于三组权重(分别对应于point_wisedepth_wiselow_pass)来预先计算内核,然后调用一次nn.functional.conv2d。不过,实现这种预计算非常困难,因为它需要对参数张量的形状进行相当复杂的转换才能保留操作的确切语义(例如,从1x1内核D中减去空间内核P) 。我尝试在10分钟内获取此文件,但失败了,但是如果这非常重要,请考虑在PyTorch论坛上提问或在评论中让我知道。

关于Keras为何处理它,我不确定,但是一个很强的猜测是它要归功于TensorFlow。与PyTorch不同,TensorFlow使用(主要是)静态计算图,可以提前对其进行分析和优化。我希望TensorFlow能够识别三个线性运算符的序列并将它们组合为一个,或者至少部分利用它们的线性来优化内存使用。