我正在尝试使用 MNIST 数据集构建 GAN,当我开始训练它时,我可以看到生成器和鉴别器都卡住了(即损失值不会改变,生成器输出相同白噪声无意义的图像)。我尝试更改鉴别器和生成器的损失函数、学习率、添加和删除层,但没有任何尝试奏效。附上我的代码:
def tensor_to_plt_im(im: torch.Tensor):
return im.permute(1, 2, 0)
def d_loss(discriminator_generated_x, discriminator_true_x):
return -0.5 * torch.mean(torch.log(discriminator_true_x + 1e-8)) \
- 0.5 * torch.mean(torch.log(1 - discriminator_generated_x + 1e-8))
def gen_loss(discriminator_generated_x, discriminator_true_x):
return -1 * d_loss(discriminator_generated_x, discriminator_true_x)
def gen_loss_non_saturating(discriminator_generated_x):
return -0.5 * torch.mean(torch.log(discriminator_generated_x + 1e-8))
def d_loss_least_squares(discriminator_generated_x, discriminator_true_x):
return 0.5 * torch.mean(torch.square(discriminator_true_x + 1e-8 - 1)) \
+ 0.5 * torch.mean(torch.square(discriminator_generated_x + 1e-8))
def gen_loss_least_squares(discriminator_generated_x):
return 0.5 * torch.mean(torch.square(discriminator_generated_x + 1e-8 - 1))
class MNISTGen(nn.Module):
def __init__(self, num_input_channels):
super(MNISTGen, self).__init__()
self.conv = nn.Sequential(
# input is (num_input_channels*4) x 4 x 4
nn.ConvTranspose2d(num_channels * 4, num_channels * 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(num_channels * 2),
# nn.ReLU(True),
# size. (num_channels*2) x 7 x 7
nn.ConvTranspose2d(num_channels * 2, num_channels, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(num_channels),
# nn.ReLU(True),
# size. (num_channels) x 14 x 14
nn.ConvTranspose2d(num_channels, num_input_channels, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(num_input_channels),
nn.Sigmoid() # the final output image size. (num_input_channels) x 28 x 28
)
self.linear = nn.Sequential(
# size. (latent_vec_size)
nn.Linear(latent_vec_size, 64),
nn.BatchNorm1d(64),
# nn.LeakyReLU(0.2, inplace=True),
# size. (64)
nn.Linear(64, 256),
nn.BatchNorm1d(256),
# size. 256
nn.Linear(256, 512),
nn.BatchNorm1d(512),
# size. 512
nn.Linear(512, (num_channels * 4) * 4 * 4),
nn.BatchNorm1d((num_channels * 4) * 4 * 4),
)
def forward(self, x):
x = self.linear(x)
x = x.view(-1, num_channels * 4, 4, 4)
return self.conv(x)
class MNISTDisc(nn.Module):
def __init__(self, num_input_channels):
super(MNISTDisc, self).__init__()
self.conv = nn.Sequential(
# input is (num_input_channels) x 28 x 28
nn.Conv2d(num_input_channels, num_channels, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# size. (num_channels) x 14 x 14
nn.Conv2d(num_channels, num_channels * 2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(num_channels * 2),
nn.LeakyReLU(0.2, inplace=True),
# size. (num_channels*2) x 7 x 7
nn.Conv2d(num_channels * 2, num_channels * 4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(num_channels * 4),
nn.LeakyReLU(0.2, inplace=True), # size. (num_channels*4) x 4 x 4
)
self.linear = nn.Sequential(
# size. (num_channels_d*4) x 4 x 4
nn.Linear((num_channels * 4) * 4 * 4, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2, inplace=True),
# size. (512)
nn.Linear(512, 256),
nn.BatchNorm1d(256),
# size. 256
nn.Linear(256, 64),
nn.BatchNorm1d(64),
# size. 64
nn.Linear(64, 1),
nn.Sigmoid() # final classification
)
def forward(self, x):
x = self.conv(x)
x = x.view(-1, (num_channels * 4) * 4 * 4) # reshape input so that it fits fc layer
return self.linear(x)
def train(g_net: nn.Module, d_net: nn.Module):
G_losses, D_losses, iters, gen_lst = [], [], 0, []
optimizerD = optim.Adam(d_net.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(g_net.parameters(), lr=lr, betas=(beta1, 0.999))
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
# Train with all-real batch
d_net.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
# Forward pass real batch through D
real_output = d_net(real_cpu).view(-1)
# errD_real.backward()
D_x = real_output.mean().item()
# Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, latent_vec_size, device=device)
# Generate fake image batch with G
fake = g_net(noise)
# Classify all fake batch with D
fake_output = d_net(fake.detach()).view(-1)
# Calculate the gradients for this batch, accumulated (summed) with previous gradients
loss = d_loss(fake_output, real_output)
loss.backward()
# errD_fake.backward()
D_G_z1 = fake_output.mean().item()
# Compute error of D as sum over the fake and the real batches
errD = loss
# Update D
optimizerD.step()
# perform a generator iteration every 'discriminator_iterations' steps of the discriminator
if i % discriminator_iterations == 0:
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
g_net.zero_grad()
# Since we just updated D, perform another forward pass of all-fake batch through D
fake_output = d_net(fake).view(-1)
# Calculate G's loss based on this output
# errG = criterion(fake_output, label)
errG = gen_loss_non_saturating(fake_output)
# Calculate gradients for G
errG.backward()
D_G_z2 = fake_output.mean().item()
# Update G
optimizerG.step()
# Output training stats
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
G_losses.append(errG.item())
D_losses.append(errD.item()) # Save Losses for plotting later
# Check how the generator is doing by saving G's output on fixed_noise
if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
with torch.no_grad():
fake = g_net(fixed_noise).detach().cpu()
gen_lst.append(vutils.make_grid(fake))
plt.imshow(tensor_to_plt_im(gen_lst[-1]))
plt.show()
iters += 1
if __name__ == '__main__':
dataloader = generate_mnist_data_set()
generator = MNISTGen(num_input_channels=1)
discriminator = MNISTDisc(num_input_channels=1)
generator.apply(weights_init), discriminator.apply(weights_init)
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(image_size, latent_vec_size, device=device)
train(generator, discriminator)
我将不胜感激任何建议、评论和一般指导,因为就像我说的那样,我完全被困住了。
提前致谢