Discriminator accuracy and generator fool rate both approach 1.0 with BatchNormalization layers

Asked: 2019-07-04 19:34:01

Tags: tensorflow keras batch-normalization generative-adversarial-network

I am building a generative adversarial network in Keras/TensorFlow to generate images of dogs. The first time I put the network together, everything worked as expected: the generator fool rate and the discriminator accuracy moved inversely, as they should. To improve the network, I added several BatchNormalization layers to both the generator and the discriminator. I managed to freeze the discriminator's batchnorm layers while training the generator by setting layer._per_input_updates = {}, and I confirmed that the layers stay frozen during generator training by comparing the discriminator's actual weight matrices before and after each generator training iteration. However, when I print out the generator fool rate and the discriminator accuracy, they both converge to 1.0. This only happens when the BatchNormalization layers are present. My questions are: a) does this mean the network is not training correctly? And b) if so, how can I get it to train correctly while still using the BatchNormalization layers?
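For context, a Keras BatchNormalization layer carries four weight arrays (gamma, beta, moving_mean, moving_variance), and get_weights() returns all of them, which is why the before/after weight comparison in the code below also covers the moving statistics. A minimal sketch of that layout using tf.keras (the _per_input_updates assignment is the same private-API hack mentioned above):

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
bn.build((None, 8))            # any feature dimension works for this check
for w in bn.weights:
    print(w.name)              # gamma, beta, moving_mean, moving_variance
bn.trainable = False           # freezes gamma/beta against gradient updates only
bn._per_input_updates = {}     # private-API hack: clears the moving-average update ops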

Here is my code (please ignore the many print statements and comments, most of which were for debugging):

import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Activation, BatchNormalization,Conv2D, MaxPooling2D, Dense, Dropout, LeakyReLU,Reshape,Input
from tensorflow.keras.optimizers import Adam
import pickle

allimages = []
for i in range(1,21):
  allimages+=pickle.load(open('ImageFile' + str(i) + '.0','rb'))

def genDogs(num):
  allinds = np.random.randint(0,len(allimages),num)
  finaldata = [allimages[i] for i in allinds]
  return np.array(finaldata)

def getWeights(m):
    return [list(w.reshape((-1))) for w in m.get_weights()]
def compareWeights(l1,l2):
    assert len(l1)==len(l2)
    for w1,w2 in zip(l1,l2):
        if not np.array_equal(w1,w2):
            print('blaaaahhhh! No')
            return False
    print('ALL WEIGHTS SAME')
    return True

#run_opts = tf.RunOptions(report_tensor_allocations_upon_oom = True)
#generator
generator = Sequential()
#generator input can be a 100-len vector and output a 4096*3=12288 length vector
generator.add(Dense(512,input_shape=[100],activation='tanh'))
generator.add(Dense(2048))
generator.add(BatchNormalization())
generator.add(LeakyReLU())
generator.add(Dense(4096))
generator.add(BatchNormalization())
generator.add(Activation('tanh'))
generator.add(Dense(12288,activation='sigmoid'))
generator.add(Reshape([64,64,3]))
generator.compile(loss='binary_crossentropy',optimizer=Adam())#,options=run_opts)


#discriminator
discriminator = Sequential()
#input shape [batch,64,64,3]
discriminator.add(Conv2D(64,(4,4),padding='same',input_shape=[64,64,3])) #outputs [None,64,64,64]
discriminator.add(MaxPooling2D()) #new dims [None, 32,32,64]
discriminator.add(BatchNormalization())
discriminator.add(Dropout(.1))
discriminator.add(LeakyReLU())
discriminator.add(Conv2D(128,(2,2),padding='same')) #[None,32,32,128]
discriminator.add(MaxPooling2D()) #[None,16,16,128]
discriminator.add(BatchNormalization())
discriminator.add(Dropout(.1))
discriminator.add(LeakyReLU())
discriminator.add(Conv2D(32,(4,4),padding='valid',activation='tanh')) #shape [None,13,13,32]
discriminator.add(BatchNormalization())
discriminator.add(Reshape([5408]))#5408 size
discriminator.add(Dense(256,activation=LeakyReLU()))
discriminator.add(Dense(1,activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy',optimizer=Adam(),metrics=['accuracy'])#,options=run_opts)

#combined
discriminator.trainable = False
for layer in discriminator.layers:
    layer.trainable = False
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer._per_input_updates = {}
gan_input = Input([100])
outs = discriminator(generator(gan_input))
combined = Model(inputs=gan_input,outputs=outs)
combined.compile(loss='binary_crossentropy',metrics=['accuracy'],optimizer=Adam())#,options=run_opts)
print('###########################################loaded models!')
def showIm():
  noise1 = np.random.random([1,100])
  generated = generator.predict_on_batch(noise1)
  plt.imshow(generated[0])

def iteration(halfbatch):

  #train the discriminator once first
  discriminator.trainable=True
  noise1 = np.random.random([halfbatch,100])
  generated = generator.predict_on_batch(noise1)
  dogs = genDogs(halfbatch)
  together = np.concatenate((generated,dogs),axis=0)
  y = np.array([0 for _ in range(halfbatch)] + [1 for _ in range(halfbatch)])
  #print(discriminator.predict(together))
  #print('#######################################fitting discriminator!')
  outs = discriminator.train_on_batch(together,y)
  print('discriminator loss: ' + str(outs[0]) + ', discriminator accuracy: '+ str(outs[1]))

  preW = getWeights(discriminator)
  #train the combined network
  for _ in range(1):
    noise2 = np.random.random([halfbatch,100])
    labels = np.ones([halfbatch])
    #print(combined.predict(noise2))
    #print('#########################################fitting generator!')
    outs2 = combined.train_on_batch(noise2,labels)
    postW = getWeights(discriminator)
    compareWeights(preW,postW)
    print('generator loss: ' + str(outs2[0]) + ', generator fool rate: ' + str(outs2[1]))

for i in range(100):
    iteration(64)
generator.save('generator0.h5')

Here is a sample of the output when I run this:

generator loss: 1.7851067, generator fool rate: 0.8125
WARNING:tensorflow:Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
discriminator loss: 1.0960464e-07, discriminator accuracy: 1.0
ALL WEIGHTS SAME
generator loss: 1.630374, generator fool rate: 0.84375
WARNING:tensorflow:Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
discriminator loss: 1.0960464e-07, discriminator accuracy: 1.0
ALL WEIGHTS SAME

As you can see, the discriminator weights do not appear to be trained during the generator iterations. Nevertheless, the generator fool rate and the discriminator accuracy both converge to 1.0. This did not happen when I got rid of the BatchNormalization layers.
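For reference, here is a minimal sketch of how the combined model could instead be wired so that the discriminator runs in inference mode while the generator trains (untested against this exact setup; training=False is the standard tf.keras call argument, and with it the discriminator's BatchNormalization and Dropout layers use their moving statistics rather than the statistics of the all-fake batch):

gan_input = Input([100])
fake_images = generator(gan_input)
gan_output = discriminator(fake_images, training=False)  # BN/Dropout run in inference mode here
combined = Model(inputs=gan_input, outputs=gan_output)
combined.compile(loss='binary_crossentropy', metrics=['accuracy'], optimizer=Adam())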

0 Answers:

There are no answers yet.