我正在构建一个双视频判别器 GAN(DVD-GAN),它使用一个时间判别器和一个空间判别器,二者需要协同工作。
模型类与优化器的设置如下:
# Instantiate the networks FIRST: each optimizer below needs the parameters
# of an already-constructed model (the original order raised NameError).
g = GeneratorNet()
ds = SpatialDiscriminatorNet()
dt = TemporalDiscriminatorNet()

# One Adam optimizer per network; both discriminators use lr=0.0005,
# the generator uses lr=0.0001.
disS_optim = optim.Adam(ds.parameters(), lr=0.0005)
disT_optim = optim.Adam(dt.parameters(), lr=0.0005)
gen_optim = optim.Adam(g.parameters(), lr=0.0001)

# Binary cross-entropy shared by all adversarial losses. nn.BCELoss expects
# probabilities in [0, 1]; NOTE(review): assumes ds/dt end in a sigmoid —
# confirm, otherwise switch to nn.BCEWithLogitsLoss.
loss = nn.BCELoss()
训练循环如下:
# Per-iteration loss history, stored as plain Python floats (via .item()) so
# that no autograd graphs are kept alive and matplotlib can plot them directly.
lossDs_values = []
lossG_values = []
lossDt_values = []
num_epochs = 3
# Nominal batch size. The loop reads the real size from each batch instead, so
# a smaller final batch from data_loader cannot desync fake and real batches.
batch_size = 16

for epoch in trange(num_epochs):
    for video, _label in tqdm(data_loader):
        # video shape: (batch, n_frames, 3, 64, 64), per the original note.
        cur_batch, seq_len = video.size(0), video.size(1)
        # Reorder to (n_frames, batch, 3, 64, 64).
        # NOTE(review): assumes g, ds and dt all use this frame-major layout — confirm.
        real_data = video.permute(1, 0, 2, 3, 4)

        # ---- 1) Train the spatial and temporal discriminators ----
        disS_optim.zero_grad()
        disT_optim.zero_grad()

        # Detached: discriminator losses must not backpropagate into g.
        fake_data = g(_generator_noise_input(bs=cur_batch, seqLen=seq_len)).detach()

        prediction_real_Ds = ds(real_data)
        error_real_Ds = loss(prediction_real_Ds, _ones_target(prediction_real_Ds.shape[0]))
        prediction_real_Dt = dt(real_data)
        error_real_Dt = loss(prediction_real_Dt, _ones_target(prediction_real_Dt.shape[0]))

        prediction_fake_Ds = ds(fake_data)
        error_fake_Ds = loss(prediction_fake_Ds, _zeros_target(prediction_fake_Ds.shape[0]))
        prediction_fake_Dt = dt(fake_data)
        error_fake_Dt = loss(prediction_fake_Dt, _zeros_target(prediction_fake_Dt.shape[0]))

        error_Ds = error_real_Ds + error_fake_Ds
        error_Dt = error_real_Dt + error_fake_Dt
        error_Ds.backward()
        error_Dt.backward()
        disS_optim.step()
        disT_optim.step()

        # ---- 2) Train the generator ----
        gen_optim.zero_grad()
        # NOT detached: this sample must stay connected to g's graph, otherwise
        # error_G.backward() produces no gradient for the generator at all.
        fake_data2 = g(_generator_noise_input(bs=cur_batch, seqLen=seq_len))
        prediction_G_Ds = ds(fake_data2)
        prediction_G_Dt = dt(fake_data2)
        error_G_Ds = loss(prediction_G_Ds, _ones_target(prediction_G_Ds.shape[0]))
        error_G_Dt = loss(prediction_G_Dt, _ones_target(prediction_G_Dt.shape[0]))
        error_G = error_G_Ds + error_G_Dt
        # This backward also writes grads into ds/dt; they are cleared by the
        # zero_grad() calls at the start of the next iteration.
        error_G.backward()
        gen_optim.step()

        # .item() detaches each loss to a float for logging/plotting.
        lossDs_values.append(error_Ds.item())
        lossDt_values.append(error_Dt.item())
        lossG_values.append(error_G.item())
        print(lossDs_values[-1], lossDt_values[-1], lossG_values[-1])

plt.plot(lossG_values, label='G error')
plt.plot(lossDs_values, label='Ds error')
plt.plot(lossDt_values, label='Dt error')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.legend()
plt.show()