使用预训练鉴别器进行发电机训练

时间:2018-11-19 07:34:20

标签: python-3.x tensorflow keras deep-learning generative-adversarial-network

我有一个受过训练的鉴别器,可以以 97%的准确度区分真实 伪造图像。现在,我想训练一个生成器,该生成器强制将真实图像重新生成为伪图像。可以说,鉴别器给出1代表真实图像,给出0代表虚假图像。基于鉴别器损耗,我想训练发生器以产生所需的输出。谁能帮助我,我做对了还是错了?

我的代码结构适用于GANS

def construct_discriminator(X,Y):
  discriminator=trained_model.evaluate(X,Y,verbose=0)
  return discriminator

# Creates the generator model. This model has an input of random noise and generates an image that will try mislead the discriminator.
def construct_generator(image_shape):
    generator = Sequential()
    # First Downscaling The Real Image and trying to produce a Fake Image
    Downscale Layer 1
    Downscale Layer 2
    Downscale Layer 3
    #Upscaling
    Upscale Layer 1
    Upscale Layer 2
    Upscale Layer 3
    optimizer = Adam(lr=0.00015, beta_1=0.5)
    generator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=None)
    generator.summary()
    return generator

def train_dcgan(batch_size, epochs, image_shape, labels):
    # Build the adversarial model that consists in the generator output 
    connected to the discriminator
    generator = construct_generator(image_shape)
    discriminator=construct_discriminator(image_shape, labels)
    gan = Sequential()
    # Only false for the adversarial model
    discriminator.trainable = False
    gan.add(generator)
    gan.add(discriminator)

    number_of_batches = int(len(Fake_Images)/ batch_size)

    # Let's train the DCGAN for n epochs
    for epoch in range(epochs):
      print("Epoch " + str(epoch+1) + "/" + str(epochs) + " :")
      for batch_number in range(number_of_batches):
        start_time = time.time()
        # The last batch is smaller than the other ones, so we need to take that into account
        noise = random_batch(batch_size,Fake_Images)
        labels=(np.zeros(batch_size))
        labels=labels.reshape(-1,)
        noise=noise.reshape(-1, 304, 304, 1)

        # Generate images
        generated_images = generator.predict(noise)
        g_loss=discriminator(noise,labels)

        # We try to mislead the discriminator by giving the opposite labels
        g_loss += gan.train_on_batch(noise, labels)
        time_elapsed = time.time() - start_time

        # Display and plot the results
        print("Batch " + str(batch_number + 1) + "/" + str(number_of_batches) + " generator loss | discriminator loss : " + str(g_loss) + " | " + 
                  str(d_loss) + ' - batch took ' + str(time_elapsed) + ' s.')


def main():

    batch_size = 16
    image_shape = (304, 304, 1)
    labels=(16,)
    epochs = 200
    train_dcgan(batch_size, epochs,image_shape,labels)

if __name__ == "__main__":
    main()  

它给我错误'tuple'对象没有属性'ndim'

0 个答案:

没有答案