我一直在尝试为音乐生成编写GAN,但是我的训练似乎有问题,因为真实数据的损失和生成器的损失正在增加,而虚假数据的损失几乎是零。我的生成器是一个简单的全连接网络,而鉴别器是一个卷积网络。这是我的训练代码和造成的损失。有人可以帮我解决我做错的事情吗?
def train_model(modelD, modelG, criterion, optimizerD, optimizerG, n_epochs = 5):
random = torch.from_numpy(np.random.rand(100)).float()
measures = torch.from_numpy(notes.reshape(-1, 97, 64)).float()
num_measures = measures.shape[0]
losses = []
# set the model to train mode initially
model.train()
for epoch in range(n_epochs):
print("Epoch: " + str(epoch+1))
for measure in measures:
real_losses = []
fake_losses = []
G_losses = []
#----------------------------------
# Train discriminator with all-real batch
inputsD, labelsD = measure, torch.Tensor([1]).long()
inputsD = inputsD.to(device)
labelsD = labelsD.to(device)
optimizerD.zero_grad()
# forward + backward + optimize
outputsD = modelD(inputsD)
lossD_real = criterion(outputsD, labelsD)
lossD_real.backward(retain_graph=True)
real_losses.append(lossD_real.detach().numpy())
#----------------------------------
# Train discriminator with all-fake batch
inputsD, labelsD = modelG(random), torch.Tensor([0]).long()
inputsD = inputsD.to(device)
labelsD = labelsD.to(device)
# forward + backward + optimize
outputsD = modelD(inputsD)
lossD_fake = criterion(outputsD, labelsD)
lossD_fake.backward(retain_graph=True)
fake_losses.append(lossD_fake.detach().numpy())
optimizerD.step()
#----------------------------------
# Train G network
optimizerG.zero_grad()
outputs = modelD(inputsD)
# forward + backward + optimize
lossG = criterion(outputs, torch.Tensor([1]).long())
lossG.backward()
G_losses.append(lossG.detach().numpy())
optimizerG.step()
print("Real Loss: " + str(sum(real_losses)))
print("Fake Loss: " + str(sum(fake_losses)))
print("G Loss: " + str(sum(G_losses)))
生成器和鉴别器:
class Generator(torch.nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.layers = nn.Sequential(
nn.Linear(100, 512),
nn.ReLU(),
nn.Linear(512, 2048),
nn.ReLU(),
nn.Linear(2048, 6208))
def forward(self, x):
x = self.layers(x)
x = x.view(97, 64)
return x
class Discriminator(torch.nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer3 = nn.Sequential(
nn.Linear(12288, 1000),
nn.ReLU(),
nn.Linear(1000, 50),
nn.ReLU(),
nn.Linear(50, 2),
nn.LogSoftmax(dim=1))
def forward(self, x):
x = x.view(-1, 1, 97, 64)
x = self.layer1(x)
x = self.layer2(x)
x = x.view(1, -1)
x = self.layer3(x)
return x
设置:
modelD = Discriminator()
modelG = Generator()
criterion = nn.CrossEntropyLoss()
optimizerD = optim.SGD(modelD.parameters(), lr=0.001, momentum=0.9)
optimizerG = optim.SGD(modelG.parameters(), lr=0.1, momentum=0.9)
epochs = 100
输出:
Epoch: 1
Real Loss: 0.6634323596954346
Fake Loss: 0.7299578189849854
G Loss: 0.6627244353294373
Epoch: 2
Real Loss: 0.7140370011329651
Fake Loss: 0.6786684989929199
G Loss: 0.713873028755188
Epoch: 3
Real Loss: 0.7706409692764282
Fake Loss: 0.6278889775276184
G Loss: 0.7697595953941345
Epoch: 4
Real Loss: 0.8299549221992493
Fake Loss: 0.5787992477416992
G Loss: 0.8297436833381653
Epoch: 5
Real Loss: 0.8929369449615479
Fake Loss: 0.5315663814544678
G Loss: 0.8940553069114685
Epoch: 6
Real Loss: 0.9610943794250488
Fake Loss: 0.48426902294158936
G Loss: 0.9678256511688232
Epoch: 7
Real Loss: 1.0396536588668823
Fake Loss: 0.42658042907714844
G Loss: 1.0722308158874512
Epoch: 8
Real Loss: 1.1386523246765137
Fake Loss: 0.3636009395122528
G Loss: 1.2055680751800537
Epoch: 9
Real Loss: 1.2815502882003784
Fake Loss: 0.30338314175605774
G Loss: 1.3635222911834717
Epoch: 10
Real Loss: 1.5050387382507324
Fake Loss: 0.24401997029781342
G Loss: 1.5601530075073242
Epoch: 11
Real Loss: 1.8629376888275146
Fake Loss: 0.1645389348268509
G Loss: 1.9329965114593506
Epoch: 12
Real Loss: 2.417865514755249
Fake Loss: 0.0913706049323082
G Loss: 2.5052433013916016
Epoch: 13
Real Loss: 3.132223606109619
Fake Loss: 0.04399847984313965
G Loss: 3.220311403274536
Epoch: 14
Real Loss: 3.8472046852111816
Fake Loss: 0.021415308117866516
G Loss: 3.9200611114501953
Epoch: 15
Real Loss: 4.434380531311035
Fake Loss: 0.011867986992001534
G Loss: 4.49014949798584
Epoch: 16
Real Loss: 4.873379707336426
Fake Loss: 0.007633917964994907
G Loss: 4.9159369468688965
Epoch: 17
Real Loss: 5.198086738586426
Fake Loss: 0.0055174920707941055
G Loss: 5.230365753173828
Epoch: 18
Real Loss: 5.446741104125977
Fake Loss: 0.004297664389014244
G Loss: 5.473706245422363
Epoch: 19
Real Loss: 5.646566390991211
Fake Loss: 0.0035177513491362333
G Loss: 5.66972541809082
Epoch: 20
Real Loss: 5.814294338226318
Fake Loss: 0.002974849194288254
G Loss: 5.834484100341797