Generator loss not decreasing in music GAN

Time: 2019-08-16 14:09:19

Tags: python pytorch generative-adversarial-network

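I have been trying to write a GAN for music generation, but something seems to be wrong with my training: the loss on real data and the generator loss keep increasing, while the loss on fake data drops to almost zero. My generator is a simple fully connected network and the discriminator is a convolutional network. Below are my training code and the resulting losses. Can anyone help me figure out what I am doing wrong?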

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def train_model(modelD, modelG, criterion, optimizerD, optimizerG, n_epochs=5):
    # `notes` (the training data) and `device` are assumed to be defined elsewhere.
    random = torch.from_numpy(np.random.rand(100)).float()
    measures = torch.from_numpy(notes.reshape(-1, 97, 64)).float()
    num_measures = measures.shape[0]
    losses = []
    # set both models to train mode initially
    modelD.train()
    modelG.train()
    for epoch in range(n_epochs):
        print("Epoch: " + str(epoch+1))
        for measure in measures:
            real_losses = []
            fake_losses = []
            G_losses = []
            #----------------------------------
            # Train discriminator with all-real batch
            inputsD, labelsD = measure, torch.Tensor([1]).long()
            inputsD = inputsD.to(device)
            labelsD = labelsD.to(device)
            optimizerD.zero_grad()

            # forward + backward + optimize
            outputsD = modelD(inputsD)
            lossD_real = criterion(outputsD, labelsD)
            lossD_real.backward(retain_graph=True)
            real_losses.append(lossD_real.detach().numpy())
            #----------------------------------
            # Train discriminator with all-fake batch
            inputsD, labelsD = modelG(random), torch.Tensor([0]).long()
            inputsD = inputsD.to(device)
            labelsD = labelsD.to(device)

            # forward + backward + optimize
            outputsD = modelD(inputsD)
            lossD_fake = criterion(outputsD, labelsD)
            lossD_fake.backward(retain_graph=True)
            fake_losses.append(lossD_fake.detach().numpy())
            optimizerD.step()
            #----------------------------------
            # Train G network
            optimizerG.zero_grad()
            outputs = modelD(inputsD)

            # forward + backward + optimize
            lossG = criterion(outputs, torch.Tensor([1]).long())
            lossG.backward()
            G_losses.append(lossG.detach().numpy())
            optimizerG.step()

        print("Real Loss: " + str(sum(real_losses)))
        print("Fake Loss: " + str(sum(fake_losses)))
        print("G Loss: " + str(sum(G_losses)))
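
For comparison, here is a minimal sketch of a more conventional update order. It is not a confirmed fix for the code above. It assumes tiny stand-in networks (`tiny_g` and `tiny_d` are hypothetical, not the models from this question), random placeholder data, fresh noise drawn at every step instead of one fixed vector, `detach()` on the generated sample during the discriminator update, and loss lists created once per epoch so that the reported values cover the whole epoch.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins for the real models, kept tiny so the sketch runs fast.
tiny_g = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 12))
tiny_d = nn.Sequential(nn.Linear(12, 16), nn.ReLU(), nn.Linear(16, 2))

criterion = nn.CrossEntropyLoss()  # takes raw scores, so tiny_d has no LogSoftmax
opt_d = optim.SGD(tiny_d.parameters(), lr=0.01)
opt_g = optim.SGD(tiny_g.parameters(), lr=0.01)

real_label = torch.tensor([1])
fake_label = torch.tensor([0])
real_data = torch.rand(5, 12)  # placeholder data, one "measure" per row


def train_epoch():
    """Run one epoch and return the mean real, fake and generator losses."""
    real_losses, fake_losses, g_losses = [], [], []
    for sample in real_data:
        # Discriminator step: one real sample, then one detached fake sample.
        opt_d.zero_grad()
        loss_real = criterion(tiny_d(sample.unsqueeze(0)), real_label)
        fake = tiny_g(torch.rand(1, 8))  # fresh noise every step
        loss_fake = criterion(tiny_d(fake.detach()), fake_label)
        (loss_real + loss_fake).backward()
        opt_d.step()

        # Generator step: score the same fake sample again, labelled as real.
        opt_g.zero_grad()
        loss_g = criterion(tiny_d(fake), real_label)
        loss_g.backward()
        opt_g.step()

        real_losses.append(loss_real.item())
        fake_losses.append(loss_fake.item())
        g_losses.append(loss_g.item())
    n = len(real_data)
    return sum(real_losses) / n, sum(fake_losses) / n, sum(g_losses) / n


for epoch in range(3):
    real, fake_loss, gen = train_epoch()
    print(f"Epoch {epoch + 1}: real={real:.4f} fake={fake_loss:.4f} G={gen:.4f}")
```

Detaching the fake sample keeps the discriminator update from building gradients in the generator, which avoids the need for `retain_graph=True`.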

Generator and discriminator:

class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 512),
            nn.ReLU(),
            nn.Linear(512, 2048),
            nn.ReLU(),
            nn.Linear(2048, 6208))

    def forward(self, x):
        x = self.layers(x)
        x = x.view(97, 64)
        return x

class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(
            nn.Linear(12288, 1000),
            nn.ReLU(),
            nn.Linear(1000, 50),
            nn.ReLU(),
            nn.Linear(50, 2),
            nn.LogSoftmax(dim=1))

    def forward(self, x):
        x = x.view(-1, 1, 97, 64)
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(1, -1)
        x = self.layer3(x)
        return x
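
The `12288` in `nn.Linear(12288, 1000)` can be checked by hand. A minimal sketch in plain Python, using the standard output-size formula for convolution and pooling layers (the helpers `out_size` and `flattened_features` are hypothetical names, not PyTorch functions):

```python
def out_size(size, kernel, stride, padding=0):
    """Output length along one dimension for a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(height, width, channels=32):
    """Trace a (height, width) input through the two conv + pool stages."""
    for _ in range(2):
        # Conv2d(kernel_size=5, stride=1, padding=2) keeps the size unchanged.
        height = out_size(height, kernel=5, stride=1, padding=2)
        width = out_size(width, kernel=5, stride=1, padding=2)
        # MaxPool2d(kernel_size=2, stride=2) halves it, rounding down.
        height = out_size(height, kernel=2, stride=2)
        width = out_size(width, kernel=2, stride=2)
    return channels * height * width


print(flattened_features(97, 64))  # 32 * 24 * 16 = 12288
```

So the layer sizes are consistent for a single 97 x 64 input. Note that `x.view(1, -1)` in `forward` only works for a batch of one, since a larger batch would be flattened into a single row.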

Setup:

modelD = Discriminator()
modelG = Generator()
criterion = nn.CrossEntropyLoss()
optimizerD = optim.SGD(modelD.parameters(), lr=0.001, momentum=0.9)
optimizerG = optim.SGD(modelG.parameters(), lr=0.1, momentum=0.9)
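
One detail in this setup: `nn.CrossEntropyLoss` already applies a log-softmax to its input, while the discriminator above ends in `nn.LogSoftmax`. The plain-Python sketch below (with a hypothetical `log_softmax` helper and made-up scores) suggests that the double application is redundant rather than harmful, because log-softmax of values that are already log-probabilities returns them unchanged. Pairing `nn.LogSoftmax` with `nn.NLLLoss`, or dropping the final `nn.LogSoftmax` layer, would express the same computation more directly.

```python
import math


def log_softmax(values):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(values)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in values))
    return [v - log_total for v in values]


logits = [1.5, -0.3]       # made-up discriminator scores
once = log_softmax(logits)
twice = log_softmax(once)  # what a second log-softmax would compute on top

print(once)
print(twice)
# The two lists agree up to floating-point rounding.
```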
epochs = 100

Output:

Epoch: 1
Real Loss: 0.6634323596954346
Fake Loss: 0.7299578189849854
G Loss: 0.6627244353294373
Epoch: 2
Real Loss: 0.7140370011329651
Fake Loss: 0.6786684989929199
G Loss: 0.713873028755188
Epoch: 3
Real Loss: 0.7706409692764282
Fake Loss: 0.6278889775276184
G Loss: 0.7697595953941345
Epoch: 4
Real Loss: 0.8299549221992493
Fake Loss: 0.5787992477416992
G Loss: 0.8297436833381653
Epoch: 5
Real Loss: 0.8929369449615479
Fake Loss: 0.5315663814544678
G Loss: 0.8940553069114685
Epoch: 6
Real Loss: 0.9610943794250488
Fake Loss: 0.48426902294158936
G Loss: 0.9678256511688232
Epoch: 7
Real Loss: 1.0396536588668823
Fake Loss: 0.42658042907714844
G Loss: 1.0722308158874512
Epoch: 8
Real Loss: 1.1386523246765137
Fake Loss: 0.3636009395122528
G Loss: 1.2055680751800537
Epoch: 9
Real Loss: 1.2815502882003784
Fake Loss: 0.30338314175605774
G Loss: 1.3635222911834717
Epoch: 10
Real Loss: 1.5050387382507324
Fake Loss: 0.24401997029781342
G Loss: 1.5601530075073242
Epoch: 11
Real Loss: 1.8629376888275146
Fake Loss: 0.1645389348268509
G Loss: 1.9329965114593506
Epoch: 12
Real Loss: 2.417865514755249
Fake Loss: 0.0913706049323082
G Loss: 2.5052433013916016
Epoch: 13
Real Loss: 3.132223606109619
Fake Loss: 0.04399847984313965
G Loss: 3.220311403274536
Epoch: 14
Real Loss: 3.8472046852111816
Fake Loss: 0.021415308117866516
G Loss: 3.9200611114501953
Epoch: 15
Real Loss: 4.434380531311035
Fake Loss: 0.011867986992001534
G Loss: 4.49014949798584
Epoch: 16
Real Loss: 4.873379707336426
Fake Loss: 0.007633917964994907
G Loss: 4.9159369468688965
Epoch: 17
Real Loss: 5.198086738586426
Fake Loss: 0.0055174920707941055
G Loss: 5.230365753173828
Epoch: 18
Real Loss: 5.446741104125977
Fake Loss: 0.004297664389014244
G Loss: 5.473706245422363
Epoch: 19
Real Loss: 5.646566390991211
Fake Loss: 0.0035177513491362333
G Loss: 5.66972541809082
Epoch: 20
Real Loss: 5.814294338226318
Fake Loss: 0.002974849194288254
G Loss: 5.834484100341797
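
Each printed value is a cross-entropy loss on a single sample, because the loss lists are reset for every measure and only the last step of an epoch ends up in the sum. Such a loss can be turned back into the probability the discriminator assigned to the target class with `p = exp(-loss)`. A small sketch using the epoch 1 and epoch 20 values from the log above (the dictionary name `logged` is just for illustration):

```python
import math

# (real loss, fake loss, G loss) copied from the log above.
logged = {
    1: (0.6634323596954346, 0.7299578189849854, 0.6627244353294373),
    20: (5.814294338226318, 0.002974849194288254, 5.834484100341797),
}

for epoch, (real, fake, gen) in logged.items():
    p_real_given_real = math.exp(-real)  # D's P(real) on the real sample
    p_fake_given_fake = math.exp(-fake)  # D's P(fake) on the generated sample
    p_real_given_fake = math.exp(-gen)   # D's P(real) on the generated sample
    print(
        f"epoch {epoch}: P(real|real)={p_real_given_real:.3f} "
        f"P(fake|fake)={p_fake_given_fake:.3f} "
        f"P(real|fake)={p_real_given_fake:.3f}"
    )
```

At epoch 1 all three probabilities sit near 0.5 (a loss near ln 2, about 0.693, is chance level for two classes). By epoch 20 the discriminator gives both the real and the generated sample a probability of roughly 0.003 of being real, so it appears to label everything as fake rather than separating the two.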

0 answers:

No answers yet.