PyTorch 3D convolutional network crashes the RAM in Google Colab

Asked: 2020-09-17 08:48:50

Tags: deep-learning pytorch google-colaboratory ram

I am very new to DL (and to coding in general). I am implementing a very simple 3D CNN in PyTorch, loosely based on V-Net (code below). Currently I can get the code to run on a small test image (shape: (1, 1, 128, 128, 32)), but if I feed in something larger, e.g. (1, 1, 256, 256, 128), the model crashes Colab. Is this to be expected, or have I made some obvious mistake here? Thanks in advance.
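For scale, here is a back-of-the-envelope estimate (my own sketch, not from the question) of the activation memory the larger input implies. The stride-2 input conv turns (1, 1, 256, 256, 128) into (1, 32, 128, 128, 64), and every later block keeps that spatial size, so each layer's float32 output alone is over a hundred megabytes:

```python
# Rough activation-memory estimate for the network below (float32, 4 bytes/element).
# Shapes assume the stride-2 input conv halves each spatial dim and the later
# 3x3x3 convs with padding=1 keep the spatial size unchanged.

def activation_mb(channels, d, h, w, batch=1, bytes_per_elem=4):
    """Memory of one activation tensor, in megabytes."""
    return batch * channels * d * h * w * bytes_per_elem / 1024**2

# After the input conv, (1, 1, 256, 256, 128) becomes (1, 32, 128, 128, 64).
shape = (128, 128, 64)
channel_plan = [32, 40, 48, 56, 48, 40, 32]  # channels after each layer/block

total = sum(activation_mb(c, *shape) for c in channel_plan)
print(f"activations kept for backprop: ~{total:.0f} MB")  # ~1184 MB
# Each block also stores its two intermediate conv outputs for the backward
# pass, so the true training footprint is a few times larger still.
```

With only the per-block outputs counted this is already over a gigabyte, which is consistent with a Colab RAM crash at the larger input size.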

import torch.nn as nn

class Test3D(nn.Module):

  def __init__(self):
    super(Test3D, self).__init__()
    self.input_layer = self._conv_input()
    self.conv_layer1 = self.ResidualBlock(32, 40)
    self.conv_layer2 = self.ResidualBlock(40, 48)
    self.conv_layer3 = self.ResidualBlock(48, 56)
    self.conv_layer4 = self.ResidualBlock(56, 48)
    self.conv_layer5 = self.ResidualBlock(48, 40)
    self.conv_layer6 = self.ResidualBlock(40, 32)
    self.conv_layer7 = self._conv_output()

  def _conv_input(self):
    # stride-2 conv halves each spatial dimension
    conv_layer = nn.Sequential(
        nn.Conv3d(1, 32, kernel_size=(3, 3, 3), stride=2, padding=1)
    )
    return conv_layer

  def _conv_output(self):
    # transpose conv doubles each spatial dimension back
    conv_layer = nn.Sequential(
        nn.ConvTranspose3d(32, 6, kernel_size=2, stride=2, padding=0)
    )
    return conv_layer

  def ResidualBlock(self, in_c, out_c):
    # note: despite the name, there is no skip connection here;
    # ContBatchNorm3d comes from the V-Net reference implementation
    conv_layer = nn.Sequential(
        nn.Conv3d(in_c, out_c, kernel_size=(3, 3, 3), padding=1),
        nn.Conv3d(out_c, out_c, kernel_size=(3, 3, 3), padding=1),
        ContBatchNorm3d(out_c),
        nn.ReLU()
    )
    return conv_layer

  def forward(self, x):
    out = self.input_layer(x)
    print(out.shape)
    out = self.conv_layer1(out)
    print(out.shape)
    out = self.conv_layer2(out)
    print(out.shape)
    out = self.conv_layer3(out)
    print(out.shape)
    out = self.conv_layer4(out)
    print(out.shape)
    out = self.conv_layer5(out)
    print(out.shape)
    out = self.conv_layer6(out)
    print(out.shape)
    out = self.conv_layer7(out)
    print(out.shape)
    return out  # return the final feature map
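To sanity-check the spatial shapes, a standalone sketch using plain PyTorch layers (standing in for the full class, since `ContBatchNorm3d` is not defined in the question) reproduces the halve-keep-double behaviour of the network above:

```python
import torch
import torch.nn as nn

# One layer of each kind used in the question's network:
inp = nn.Conv3d(1, 32, kernel_size=3, stride=2, padding=1)          # halves dims
mid = nn.Conv3d(32, 32, kernel_size=3, padding=1)                   # keeps dims
out_layer = nn.ConvTranspose3d(32, 6, kernel_size=2, stride=2, padding=0)  # doubles dims

x = torch.randn(1, 1, 128, 128, 32)  # the small shape that works
with torch.no_grad():                # inference only: no graph kept for backward
    h = inp(x)          # -> (1, 32, 64, 64, 16)
    h = mid(h)          # -> (1, 32, 64, 64, 16)
    y = out_layer(h)    # -> (1, 6, 128, 128, 32)
print(y.shape)
```

Wrapping the test run in `torch.no_grad()` also keeps activations from being stored for backprop, which is useful when probing how large an input fits in RAM.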

0 Answers:

No answers yet.