在生殖对抗网络(GAN)中无法获得正确的输出

时间:2019-02-19 15:28:31

标签: neural-network artificial-intelligence pytorch generative-adversarial-network

嘿,我做了GAN的Pytorch实现,它似乎可以运行,但是即使经过长时间的培训,生成的图像也没有改善。模型太简单了吗?还是因为学习速度不正确。我的批次大小为50,lr为0.001。请问我是否需要更多详细信息。 Here is the Google Colab link

图片

Here is the pic after almost 50 epochsthe pic after almost 10 epochs我几乎看不到任何改善。 算法有问题吗?

模型

class Generator(nn.Module):

def __init__(self):

    super().__init__()

    self.fc1 = nn.Linear(100,128)
    self.fc2 = nn.Linear(128,256)
    self.fc3 = nn.Linear(256,512)
    self.fc4 = nn.Linear(512,1024)
    self.fc5 = nn.Linear(1024,784)

    self.bn2 = nn.BatchNorm1d(256)
    self.bn3 = nn.BatchNorm1d(512)
    self.bn4 = nn.BatchNorm1d(1024)

def forward(self,x):

    out = F.leaky_relu(self.fc1(x))
    out = self.bn2(F.leaky_relu(self.fc2(out)))
    out = self.bn3(F.leaky_relu(self.fc3(out)))
    out = self.bn4(F.leaky_relu(self.fc4(out)))
    out = torch.tanh(self.fc5(out))
    out = out.view((x.shape[0],-1))

    return out

class Discriminator(nn.Module):

def __init__(self):

    super().__init__()

    self.fc1 = nn.Linear(784,512)
    self.fc4 = nn.Linear(512,512)
    self.fc2 = nn.Linear(512,256)
    self.fc3 = nn.Linear(256,1)

def forward(self, x):

    x = x.reshape((x.shape[0] ,-1))
    out = F.leaky_relu(self.fc1(x))
    out = F.leaky_relu(self.fc4(out))
    out = F.leaky_relu(self.fc2(out))
    out = torch.sigmoid(self.fc3(out))

    return out

培训初始化

adv_loss = nn.BCELoss()

gen = Generator()

Dis = Discriminator()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
gen.to(device)
Dis.to(device)

gen_optimizer = optim.Adam(gen.parameters(),lr = 0.001 )
Dis_optimizer = optim.Adam(gen.parameters(),lr = 0.001 )

epochs = 200

real_labels = torch.ones((50,1) , dtype = torch.float , requires_grad =       False)
fake_labels = torch.zeros((50,1) , dtype = torch.float , requires_grad =     False)

real_labels = real_labels.to(device);
fake_labels = fake_labels.to(device);

print(gen, Dis)

os.makedirs('images', exist_ok=True)

训练机构

for e in range(epochs):

for i,(img,_) in enumerate(train_loader):

    img = img.to(device)

    gen_optimizer.zero_grad() # reset gradients to zero

    z = (torch.tensor(np.random.normal(0., 1.,(50,100)) , dtype = torch.float , device = device ))

    gen_imgs = gen(z)
    #print(gen_imgs.shape,"  ",real_labels.shape)
    #its objective is to fool so the loss is compared with ones 
    g_loss = adv_loss(Dis(gen_imgs) , fake_labels)

    g_loss.backward()

    gen_optimizer.step()

    #============end of generator==================

    Dis_optimizer.zero_grad()


    real_out = Dis(img)

    #fake_out = Dis(gen_imgs)
    #teaching the dis to tell the difference between real and fake        
    real_loss = adv_loss(real_out, real_labels)

    fake_loss = adv_loss(Dis(gen_imgs.detach()) , fake_labels)

    dis_loss = (real_loss + fake_loss)/2.0

    dis_loss.backward()

    Dis_optimizer.step()

    #print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (e, epochs, i, len(train_loader),
    #                                                    dis_loss.item(), g_loss.item()))

    batches_done = e * len(train_loader) + i
    if batches_done % 500 == 0:

        save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=5, normalize=True)
        im = gen_imgs.reshape(-1,28,28)
        im = im.to('cpu')
        plt.figure()
        plt.imshow(im[1].detach().numpy())

0 个答案:

没有答案