I am trying to implement a simple adversarial autoencoder for the MNIST dataset in PyTorch. The encoder–decoder part on its own (without the regularization phase) converges and the reconstruction error decreases (a rough sketch of that phase is included after the code below). However, when I train the discriminator and the encoder in the regularization phase, the generator loss increases while the discriminator loss decreases, which is the opposite of what I expected.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.lin1 = nn.Linear(784, 400)
        self.lin2 = nn.Linear(400, 100)
        self.lin3 = nn.Linear(100, 2)

    def forward(self, x):
        # Without dropout
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return x
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.lin1 = nn.Linear(2, 10)
        self.lin2 = nn.Linear(10, 10)
        self.lin3 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return torch.sigmoid(x)
encoder = Encoder().cuda()
discriminator = Discriminator().cuda()
encoder_gen_optimizer = optim.Adam(encoder.parameters(), lr = 3e-4)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=3e-4)
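The learning-rate schedulers stepped at the top of the loop below are created beforehand; their exact settings are not part of the question, but they are along these lines (step_size and gamma here are placeholders):

scheduler_gen = optim.lr_scheduler.StepLR(encoder_gen_optimizer, step_size=100, gamma=0.5)
scheduler_disc = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=100, gamma=0.5)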
generator_loss, discriminator_loss = [], []

# data_loader and batchsize come from the usual MNIST DataLoader setup (not shown)
for step in range(500):
    scheduler_disc.step()
    scheduler_gen.step()
    for batch_idx, (data, target) in enumerate(data_loader):
        x = Variable(data.view(data.size(0), -1)).cuda()

        # Regularisation loss - train the discriminator to tell the encoder's
        # codes (fake) from samples drawn from the prior (real)
        discriminator_optimizer.zero_grad()
        encoder_gen_optimizer.zero_grad()
        z_fake = encoder(x)
        D_fake = discriminator(z_fake)
        target_fake = torch.zeros(batchsize, dtype=torch.int64).cuda()
        target_real = torch.ones(batchsize, dtype=torch.int64).cuda()
        z_real = Variable(torch.randn(batchsize, 2)).cuda()
        D_real = discriminator(z_real)
        disc_loss = F.cross_entropy(D_real, target_real) + F.cross_entropy(D_fake, target_fake)
        disc_loss.backward()
        discriminator_optimizer.step()

        # Train the generator (encoder) so that its codes are classified as
        # coming from the normal prior
        encoder_gen_optimizer.zero_grad()
        discriminator_optimizer.zero_grad()
        z = encoder(x)
        d = discriminator(z)
        t = torch.ones(batchsize, dtype=torch.int64).cuda()
        gen_loss = F.cross_entropy(d, t)
        gen_loss.backward()
        encoder_gen_optimizer.step()

        generator_loss.append(gen_loss.item())
        discriminator_loss.append(disc_loss.item())
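For completeness, the reconstruction phase mentioned at the top (the part that does converge) is wired up roughly as follows. This is only a sketch: the decoder architecture and the binary-cross-entropy reconstruction loss are assumptions, since that code is not included in the question.

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.lin1 = nn.Linear(2, 100)
        self.lin2 = nn.Linear(100, 400)
        self.lin3 = nn.Linear(400, 784)

    def forward(self, z):
        z = F.relu(self.lin1(z))
        z = F.relu(self.lin2(z))
        return torch.sigmoid(self.lin3(z))  # pixel values in (0, 1)

decoder = Decoder().cuda()
reconstruction_optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=3e-4)

for data, target in data_loader:
    x = Variable(data.view(data.size(0), -1)).cuda()
    reconstruction_optimizer.zero_grad()
    recon = decoder(encoder(x))
    recon_loss = F.binary_cross_entropy(recon, x)
    recon_loss.backward()
    reconstruction_optimizer.step()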