我目前正在使用自定义数据库训练WGAN,该数据库已在MNIST上获得了可用的结果,我们尝试使用自己的图像。我目前正在使用完全相同的代码。 (先对鉴别器进行n次训练,然后对生成器进行训练)。计算损失等。由于填充,我仅对生成器进行了数字更改。 )
正在运行的初始代码是
for epoch in range(epochs):
start = time.time()
disc_loss = 0
gen_loss = 0
for images in train_dataset:
disc_loss += train_discriminator(images)
if disc_optimizer.iterations.numpy() % n_critic == 0:
gen_loss += train_generator()
print('Time for epoch {} is {} sec - gen_loss = {}, disc_loss = {}'.format(epoch + 1, time.time() - start, gen_loss / batch_size, disc_loss / (batch_size*n_critic)))
if epoch % save_interval == 0:
save_imgs(epoch, generator, seed)
由于使用小数运行时出现预期尺寸错误
Expected ndim=4, got ndim=3
指向我们的鉴别器的培训。我通过进行以下更改解决了这些错误:
for images in train_dataset:
images=np.expand_dims(images, axis=0)
images=images/255.
#images=images.resize
disc_loss += train_discriminator(images)
这解决了开始培训过程的问题,但我注意到了其他一些问题。
首先,我们将第一个时期后的损耗值设置为0,并且实际生成的图像在我第一次显示它之后从未改变。最后一个问题是培训时间似乎太快了。纪元2不可能在0.0002秒之内进行训练,依此类推
Time for epoch 1 is 613.6920039653778 sec - gen_loss = -0.7189221382141113, disc_loss = -1.3103094100952148
Time for epoch 2 is 0.00022411346435546875 sec - gen_loss = 0, disc_loss = 0
还有其他人处理过类似的事情吗?还是有我无法理解的循环错误?