Discriminator loss and generator loss do not converge in the regularization phase of an adversarial autoencoder

Date: 2019-04-08 19:10:06

Tags: python machine-learning pytorch autoencoder

I am trying to implement a simple adversarial autoencoder for the MNIST dataset in PyTorch. The encoder-decoder alone (without the regularization phase) converges and the reconstruction error decreases. However, when I train the discriminator and the encoder in the regularization phase, the generator loss keeps increasing while the discriminator loss keeps decreasing, which is the opposite of what I expected.
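
The data_loader and batchsize used further down come from a standard torchvision MNIST setup, roughly like this (the exact transform and batch size here are assumptions on my part; drop_last keeps every batch at exactly batchsize samples):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

batchsize = 100  # assumed batch size

# Standard MNIST training set; each image is flattened to a 784-dim vector in the training loop
train_set = datasets.MNIST('./data', train=True, download=True,
                           transform=transforms.ToTensor())
data_loader = DataLoader(train_set, batch_size=batchsize, shuffle=True, drop_last=True)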

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.lin1 = nn.Linear(784, 400)
        self.lin2 = nn.Linear(400, 100)
        self.lin3 = nn.Linear(100, 2)

    def forward(self, x):
        # No dropout; three fully connected layers with ReLU activations,
        # mapping a flattened image to a 2-D latent code
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return x
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.lin1 = nn.Linear(2, 10)
        self.lin2 = nn.Linear(10, 10)
        self.lin3 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return torch.sigmoid(x)
encoder = Encoder().cuda()
discriminator = Discriminator().cuda()
encoder_gen_optimizer = optim.Adam(encoder.parameters(), lr=3e-4)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=3e-4)

generator_loss = []       # per-batch generator losses, for plotting
discriminator_loss = []   # per-batch discriminator losses, for plotting

for step in range(500):
    # scheduler_disc / scheduler_gen are learning-rate schedulers for the two
    # optimizers, created earlier (their definition is not shown here)
    scheduler_disc.step()
    scheduler_gen.step()
    for batch_idx, (data, target) in enumerate(data_loader):
        # flatten each 28x28 image to a 784-dimensional vector
        x = Variable(data.view(data.size(0), -1)).cuda()

        # Regularisation phase, step 1: train the discriminator to tell
        # encoder outputs (fake) apart from samples of the prior (real)
        discriminator_optimizer.zero_grad()
        encoder_gen_optimizer.zero_grad()
        z_fake = encoder(x)
        D_fake = discriminator(z_fake)
        target_fake = torch.zeros(batchsize, dtype=torch.int64).cuda()
        target_real = torch.ones(batchsize, dtype=torch.int64).cuda()
        # samples from the 2-D standard normal prior the encoder should match
        z_real = Variable(torch.randn(batchsize, 2)).cuda()
        D_real = discriminator(z_real)
        disc_loss = F.cross_entropy(D_real, target_real) + F.cross_entropy(D_fake, target_fake)
        disc_loss.backward()
        discriminator_optimizer.step()

        # Regularisation phase, step 2: train the encoder (generator) so that
        # its output is classified as coming from the prior
        encoder_gen_optimizer.zero_grad()
        discriminator_optimizer.zero_grad()
        z = encoder(x)
        d = discriminator(z)
        t = torch.ones(batchsize, dtype=torch.int64).cuda()
        gen_loss = F.cross_entropy(d, t)
        gen_loss.backward()
        encoder_gen_optimizer.step()
        generator_loss.append(gen_loss.item())
        discriminator_loss.append(disc_loss.item())

Plot of generator loss and discriminator loss over training: [image: the generator loss increases while the discriminator loss decreases]
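
For completeness, the reconstruction phase that does converge (run before the regularization phase, and not shown above) looks roughly like this; the Decoder architecture and its optimizer are a sketch of my setup rather than the exact code:

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        # mirror of the encoder: 2 -> 100 -> 400 -> 784
        self.lin1 = nn.Linear(2, 100)
        self.lin2 = nn.Linear(100, 400)
        self.lin3 = nn.Linear(400, 784)

    def forward(self, z):
        z = F.relu(self.lin1(z))
        z = F.relu(self.lin2(z))
        return torch.sigmoid(self.lin3(z))  # pixel values in [0, 1]

decoder = Decoder().cuda()
reconstruction_optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=3e-4)

for batch_idx, (data, target) in enumerate(data_loader):
    x = Variable(data.view(data.size(0), -1)).cuda()
    reconstruction_optimizer.zero_grad()
    x_recon = decoder(encoder(x))
    recon_loss = F.binary_cross_entropy(x_recon, x)  # reconstruction error
    recon_loss.backward()
    reconstruction_optimizer.step()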

0 Answers