GAN training is not working: the loss values do not change

Posted: 2020-08-10 11:35:32

Tags: python machine-learning computer-vision pytorch generative-adversarial-network

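I'm building a Dual Video Discriminator GAN (DVD-GAN), which uses a temporal discriminator and a spatial discriminator that have to work together.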
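For reference, the DVD-GAN paper describes the split between the two discriminators roughly like this: the spatial discriminator scores k randomly sampled full-resolution frames individually (k = 8 in the paper), while the temporal discriminator sees the whole clip after 2×2 average pooling. The sketch below prepares both inputs under those assumptions. It assumes batch-first video tensors of shape [batch, n_frames, 3, H, W] (the question's loop permutes to time-first instead), and the function name and `k` are illustrative, not part of the question's code.

```python
import torch
import torch.nn.functional as F


def split_for_dual_discriminators(video: torch.Tensor, k: int = 8):
    """Hypothetical input preparation for a spatial and a temporal discriminator.

    video: [batch, n_frames, channels, height, width]
    Returns:
      spatial_in:  [batch * k, channels, height, width]
                   (k random full-resolution frames per clip)
      temporal_in: [batch, n_frames, channels, height // 2, width // 2]
                   (the whole clip, 2x2 average-pooled)
    """
    b, t, c, h, w = video.shape
    k = min(k, t)
    # Pick k distinct frame indices for every clip in the batch.
    idx = torch.stack([torch.randperm(t)[:k] for _ in range(b)])  # [b, k]
    frames = video[torch.arange(b).unsqueeze(1), idx]             # [b, k, c, h, w]
    spatial_in = frames.reshape(b * k, c, h, w)
    # Downsample every frame spatially with 2x2 average pooling.
    pooled = F.avg_pool2d(video.reshape(b * t, c, h, w), kernel_size=2)
    temporal_in = pooled.reshape(b, t, c, h // 2, w // 2)
    return spatial_in, temporal_in


video = torch.randn(2, 16, 3, 64, 64)
spatial_in, temporal_in = split_for_dual_discriminators(video)
print(spatial_in.shape)   # torch.Size([16, 3, 64, 64])
print(temporal_in.shape)  # torch.Size([2, 16, 3, 32, 32])
```

Splitting the work this way keeps the spatial discriminator focused on per-frame image quality and the temporal one on motion, at a lower cost per clip.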

I have set up the networks and optimizers like this:

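import torch.nn as nn
import torch.optim as optim

# Instantiate the networks first: the optimizers need their parameters
g = GeneratorNet()
ds = SpatialDiscriminatorNet()
dt = TemporalDiscriminatorNet()

disS_optim = optim.Adam(ds.parameters(), lr=0.0005)
disT_optim = optim.Adam(dt.parameters(), lr=0.0005)
gen_optim = optim.Adam(g.parameters(), lr=0.0001)
loss = nn.BCELoss()  # expects probabilities, so both discriminators must end in a sigmoid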

The training loop is:

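import matplotlib.pyplot as plt
from tqdm import tqdm, trange

lossDs_values = []
lossG_values = []
lossDt_values = []
num_epochs = 3
batch_size = 16
for epoch in trange(num_epochs):
    for i, (video, label) in enumerate(tqdm(data_loader)):
        # video shape: [batch_size, n_frames, 3, 64, 64]
        # use the actual sizes, since the last batch can be smaller than batch_size
        batch_size, n_frames = video.shape[0], video.shape[1]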

        # 1) train discriminators
        disS_optim.zero_grad()
        disT_optim.zero_grad()

        real_data = video.permute(1,0,2,3,4)
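        # detached: the discriminator step must not backpropagate into the generator
        fake_data = g(_generator_noise_input(bs=batch_size, seqLen=n_frames)).detach()
        # not detached: the generator step below needs gradients to flow back into g
        fake_data2 = g(_generator_noise_input(bs=batch_size, seqLen=n_frames))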

        prediction_real_Ds = ds(real_data)
        error_real_Ds = loss(prediction_real_Ds, _ones_target(prediction_real_Ds.shape[0]))
        
        prediction_real_Dt = dt(real_data)
        error_real_Dt = loss(prediction_real_Dt, _ones_target(prediction_real_Dt.shape[0]))

        prediction_fake_Ds = ds(fake_data)
        error_fake_Ds = loss(prediction_fake_Ds, _zeros_target(prediction_fake_Ds.shape[0]))
        
        prediction_fake_Dt = dt(fake_data)
        error_fake_Dt = loss(prediction_fake_Dt, _zeros_target(prediction_fake_Dt.shape[0]))
        
        error_Ds = error_real_Ds + error_fake_Ds
        error_Dt = error_real_Dt + error_fake_Dt
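        print(error_Ds.item(), error_Dt.item())

        # store plain floats: they can be plotted and hold no references to the autograd graph
        lossDs_values.append(error_Ds.item())
        lossDt_values.append(error_Dt.item())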
        error_Ds.backward()
        error_Dt.backward()

        disS_optim.step()
        disT_optim.step()


        # 2) train generator
        gen_optim.zero_grad()

        prediction_G_Ds = ds(fake_data2)
        prediction_G_Dt = dt(fake_data2)
        
        error_G_Ds = loss(prediction_G_Ds, _ones_target(prediction_G_Ds.shape[0]))
        error_G_Dt = loss(prediction_G_Dt, _ones_target(prediction_G_Dt.shape[0]))
        error_G = error_G_Ds + error_G_Dt
        lossG_values.append(error_G.item())
        error_G.backward()
        print(error_G.item())
        
        gen_optim.step()
        
plt.plot(lossG_values, label='G error')
plt.plot(lossDs_values, label='Ds error')
plt.plot(lossDt_values, label='Dt error')
plt.legend()  # picks up the label= of each curve
plt.show()
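When reading the resulting plot, a useful reference point is what `nn.BCELoss` returns for an undecided discriminator. If every prediction is 0.5, each BCE term equals ln 2 ≈ 0.693, so `error_Ds`, `error_Dt` and `error_G` (each a sum of two BCE terms) would all sit near 2 ln 2 ≈ 1.386. Flat curves near that value suggest the discriminators are not separating real from fake; this is a reference value, not a full diagnosis.

```python
import math

import torch
import torch.nn as nn

bce = nn.BCELoss()
n = 16
p = torch.full((n, 1), 0.5)  # a discriminator that outputs 0.5 for every clip

error_real = bce(p, torch.ones(n, 1))   # -log(0.5)     = ln 2
error_fake = bce(p, torch.zeros(n, 1))  # -log(1 - 0.5) = ln 2
error_D = error_real + error_fake       # same structure as error_Ds / error_Dt

print(round(error_D.item(), 4))   # 1.3863
print(round(2 * math.log(2), 4))  # 1.3863
```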

The problem is that the plotted loss values stay essentially constant, or change only very slightly.
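One thing worth checking for a flat generator curve: the sample fed to `ds` and `dt` in the generator step must not be detached. `.detach()` cuts the graph, so `error_G.backward()` still runs (it reaches the discriminators' parameters), but no gradient reaches the generator and `gen_optim.step()` leaves it unchanged. The toy check below uses small hypothetical `nn.Linear` stand-ins for the generator and discriminator; it only demonstrates the gradient flow, not the question's actual networks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
g = nn.Linear(8, 4)                                # hypothetical stand-in generator
d = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # hypothetical stand-in discriminator
bce = nn.BCELoss()
noise = torch.randn(16, 8)
real_label = torch.ones(16, 1)


def generator_grad_norm(detach: bool) -> float:
    """Run one generator-style backward pass and report the gradient reaching g."""
    g.zero_grad()
    d.zero_grad()
    fake = g(noise)
    if detach:
        fake = fake.detach()  # cuts the graph between d and g
    loss = bce(d(fake), real_label)
    loss.backward()           # still succeeds: d's parameters require grad
    grads = [p.grad for p in g.parameters()]
    if any(gr is None for gr in grads):
        return 0.0            # g received no gradient, so the optimizer skips it
    return sum(gr.norm().item() for gr in grads)


print(generator_grad_norm(detach=True))   # 0.0
print(generator_grad_norm(detach=False))  # a positive number
```

Detaching is still correct in the discriminator step (`fake_data`), where the generator should not be updated; only the generator step needs the undetached sample.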
