[python] Colab crashes for unknown reasons when training a GAN (PyTorch)

Date: 2021-04-14 17:00:31

Tags: python pytorch generative-adversarial-network

I'm trying to train a GAN on some monkey pictures, but as soon as I try to train it, Colab crashes for an unknown reason. I'm using 1370 monkey images of size 128*128.

I can't figure out where the problem is, please help.

By the way, the runtime is set to GPU, so that's not the issue.

from torch import optim
import torchvision
from torchvision import transforms
import torch, torch.nn as nn

batch_size = 4

generic_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.Resize((128,128)),
    transforms.ToTensor(),
    transforms.Normalize((0., 0., 0.), (6, 6, 6)),
    transforms.Grayscale(),
])

trainset=torchvision.datasets.ImageFolder(root='drive/My Drive/monkeys', transform=generic_transform)

trainloader = torch.utils.data.DataLoader(trainset,batch_size=batch_size, shuffle=True)

def _init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, 0.0, 0.02)

def gen_noise(noise_shape, n_samples, device='cuda:0'):
    return torch.randn(noise_shape, n_samples).to(device)

class Discriminator(nn.Module):
 #convolutional discriminator
  def __init__(self) -> None:
    super(Discriminator, self).__init__()
    
    self.hidden_dim = 64

    self.relu = nn.ReLU(inplace=False)
    self.sigmoid = nn.Sigmoid()

    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=32, stride = 2)
    self.maxPooling_1 = nn.MaxPool2d(kernel_size=3)
    self.conv_2 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=8, stride = 2)
    self.maxPooling_2 = nn.MaxPool2d(kernel_size=2)
    self.linear_layer = nn.Linear(in_features=self.hidden_dim, out_features=1)

  def forward(self, x) -> float:
    self.x = x
    self.x = self.relu(self.conv_1(self.x))
    self.x = self.maxPooling_1(self.x)
    self.x = self.relu(self.conv_2(self.x))
    self.x = self.maxPooling_2(self.x)
    print(self.x.shape)
    self.x = self.x.view(self.x.shape[0],
                         self.x.shape[1]*self.x.shape[2]*self.x.shape[3])
    self.x = self.sigmoid(self.linear_layer(self.x))

    return self.x



class Generator(nn.Module):
  #fully connected generator
  def __init__(self, hidden_dim, output_dim, z_dim) -> None:
    super(Generator, self).__init__()

    self.relu = nn.ReLU(inplace=False)
    self.hidden_dim = hidden_dim
    self.output_dim = output_dim
    self.z_dim = z_dim

    self.linear_layer_1 = nn.Linear(in_features=self.z_dim, out_features=self.hidden_dim)
    self.linear_layer_2 = nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim*2)
    self.linear_layer_3 = nn.Linear(in_features=self.hidden_dim*2, out_features=self.output_dim)

  def forward(self, x) -> torch.tensor:
    self.x = x
    self.x = self.relu(self.linear_layer_1(self.x))
    self.x = self.relu(self.linear_layer_2(self.x))
    self.x = self.relu(self.linear_layer_3(self.x))
    return self.x

class GAN():
  def __init__(self, hidden_dim, output_dim, z_dim, criterion, device="cuda:0") -> None:
    if device == "cuda:0":
      assert torch.cuda.is_available(), "apply gpu"

    self.hidden_dim = hidden_dim
    self.output_dim = output_dim
    self.device = device
    self.criterion = criterion
    self.z_dim = torch.tensor(z_dim).long()
    
    self.discriminator = Discriminator().to(self.device)
    self.d_opt = optim.Adam(self.discriminator.parameters(), lr=0.0001)

    self.generator = Generator(hidden_dim=self.hidden_dim, output_dim=self.output_dim, z_dim=self.z_dim).to(self.device)
    self.g_opt = optim.Adam(self.generator.parameters(), lr=0.0001)

    self.generator = self.generator.apply(_init_weights)
    self.discriminator = self.discriminator.apply(_init_weights)

class GAN_Trainer():
  def __init__(self, z_dim, model, device="cuda:0") -> None:
    self.device = device
    self.gan = model
    self.z_dim = z_dim

    self._d_mean_train_loss = None
    self._g_mean_train_loss = None
    
  def train(self, batch) -> None:
    print(1)
    self.batch = batch.to(self.device)
          
    self.noise = gen_noise(self.batch.shape[0],self.z_dim).to(self.device)
    self.gan.g_opt.zero_grad()
    self._g_output = self.gan.generator.forward(self.noise.to(self.device))

    self._g_output = self._g_output.view(self.batch.shape[0], 
                                        1, 
                                        torch.sqrt(torch.tensor(self.gan.output_dim)).int(), 
                                        torch.sqrt(torch.tensor(self.gan.output_dim)).int())
    
    print(self._g_output.shape)

    self._d_for_g_pred = self.gan.discriminator.forward(self._g_output) 
    self._g_loss = self.gan.criterion(self._d_for_g_pred, torch.zeros_like(self._d_for_g_pred))
    self._g_loss.backward()
    self.gan.g_opt.step()

    self.gan.d_opt.zero_grad()
    self._d_fake_pred = self.gan.discriminator.forward(self._g_output)
    self._d_fake_loss = self.gan.criterion(self._g_output, torch.zeros_like(self._g_output))

    self._d_real_pred = self.gan.discriminator.forward(self.batch)
    self._d_real_loss = self.gan.criterion(self.batch, torch.ones_like(self.batch))

    self._d_mean_loss = torch.mean(torch.cat((self._d_fake_loss, self._d_real_loss),0))
    self._d_mean_loss.backward(retain_graph=True)
    self.gan.d_opt.step()
    
    self._d_mean_train_loss = self._d_mean_train_loss + self._d_mean_loss.detach()
    self._g_mean_train_loss = self._g_mean_train_loss + self._g_loss.detach()

torch.cuda.empty_cache()

gan = GAN(hidden_dim=1200,
          output_dim=16384, 
          z_dim = 1000, 
          criterion=nn.BCEWithLogitsLoss())

trainer = GAN_Trainer(model=gan, z_dim=1000)

#here is where it crashes
from tqdm import trange
torch.cuda.empty_cache()

image = trainset[0][0].to("cuda:0").view(1, 
                                         trainset[0][0].shape[0],
                                         trainset[0][0].shape[1],
                                         trainset[0][0].shape[2])
trainer.train(batch=image)

Please help! I'm starting to lose my mind here, thank you! <3

1 Answer:

Answer 0 (score: 0)

I did some debugging of your code and tracked down the line where it crashes.

I tried to work out why it crashes, and it looks like your operations are not being carried out properly: some in-place operations are modifying the computation graph and causing PyTorch to fail.
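As an illustration of that class of error only (this is a generic minimal reproduction, not the exact line from the question), modifying a tensor in place after autograd has saved it for the backward pass makes PyTorch raise a RuntimeError:

import torch

x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x)   # sigmoid saves its output for the backward pass
y += 1                 # in-place update of a tensor autograd still needs
y.sum().backward()     # RuntimeError: one of the variables needed for gradient
                       # computation has been modified by an inplace operation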

On the GAN-logic side, there are some major issues in your code that you need to change. Here:

self._d_mean_loss = torch.mean(torch.cat((self._d_fake_loss, self._d_real_loss),0))

The real loss should be based on the discriminator's output, self._d_real_pred. In general, you should have three backward calls (see the sketch after this list):

  1. Feed a real batch to the discriminator and train it to output the "real" class.
  2. Generate fake images with the generator, feed them to the discriminator, and train it to output the "fake" class.
  3. Finally, for the generator, feed the fake images to the discriminator once more and expect it to be fooled: the generator is optimized against the "real" label, so that it learns to produce better-looking fakes.
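
As a rough sketch only, reusing the names from the question (gan.generator, gan.discriminator, gan.d_opt, gan.g_opt, gan.criterion) and assuming the generator output is reshaped to 1x128x128 as in the original train method, one training step along those lines could look like this:

def train_step(gan, real_batch, z_dim, device="cuda:0"):
    real_batch = real_batch.to(device)
    b = real_batch.shape[0]

    # 1. discriminator on a real batch, target label "real" (1)
    gan.d_opt.zero_grad()
    d_real = gan.discriminator(real_batch)
    d_real_loss = gan.criterion(d_real, torch.ones_like(d_real))
    d_real_loss.backward()

    # 2. discriminator on a fake batch, target label "fake" (0)
    noise = torch.randn(b, z_dim, device=device)
    fake = gan.generator(noise).view(b, 1, 128, 128)
    d_fake = gan.discriminator(fake.detach())  # detach: no generator gradients here
    d_fake_loss = gan.criterion(d_fake, torch.zeros_like(d_fake))
    d_fake_loss.backward()
    gan.d_opt.step()

    # 3. generator: feed the fakes again, but optimize towards the "real" label
    gan.g_opt.zero_grad()
    d_on_fake = gan.discriminator(fake)
    g_loss = gan.criterion(d_on_fake, torch.ones_like(d_on_fake))
    g_loss.backward()
    gan.g_opt.step()

    return d_real_loss.item() + d_fake_loss.item(), g_loss.item()

Note that nn.BCEWithLogitsLoss expects raw logits, so with this criterion the final nn.Sigmoid in the question's Discriminator should be dropped (or the criterion switched to nn.BCELoss).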

I highly recommend this tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html. It is a great walkthrough of how a GAN is trained.

Finally, I would remove all the self.... storage in the forward and train functions. Every operation on a tensor that you have stored as a member keeps extending the graph, and modifying those members afterwards can cause problems, wipe out gradients, and so on.
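
For example, the discriminator's forward pass can keep everything in local variables, using only the layers already defined in the question:

def forward(self, x):
    x = self.relu(self.conv_1(x))
    x = self.maxPooling_1(x)
    x = self.relu(self.conv_2(x))
    x = self.maxPooling_2(x)
    x = x.view(x.shape[0], -1)  # flatten; nothing is stored on self
    return self.sigmoid(self.linear_layer(x))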