使用python类在Tensorflow中构建GAN

时间:2017-10-10 11:20:58

标签: python tensorflow deep-learning keras tensorflow-gpu

我习惯在Keras设计我的GAN。但是,对于特定需求,我想将我的代码调整为Tensorflow。使用Tensorflow的大多数GAN实现使用 GAN 的类,然后使用鉴别器生成器的函数

这给出了这样的东西:

class MyGAN():
    def __init__(self):
        # various initialisation

    def generator(self, n_examples):
        ### do some business and return n_examples generated.

        return G_output

    def discrimintator(self, images):
        ### do some business with the images

        return D_Prob, D_logits
事实上,这是完全没问题的。但是,我更喜欢这个设计,其中每个部分[MyGAN,Generator,Discriminator]都是一个完整且独立的类。您只初始化主要的一个: MyGAN ,它自己处理其余的。它允许我更简单的代码组织和相对容易的代码阅读。

然而,我在一些设计模式上挣扎,用 Keras 我可以使用“输入”层,这允许我从数据集切换到Discriminator真实数据和由数据集生成的假数据发电机。用Keras 伪代码来展示这个想法只需几行:

class Generator(object):

    def __init__(self, latent_shape):

        gen_input = Input(shape=latent_shape, name='generator_input')

        #### ====== do some business ====== ####

        gen_output = Activation('tanh', name='generator_output')(previous_layer)

        self.model = Model(gen_input, gen_output)

class Discriminator(object):

    def __init__(self):

        disc_input = Input(shape=self.input_shape, name='discriminator_input')

        #### ====== do some business ====== ####

        disc_output = Activation('sigmoid', name='discriminator_output')(previous_layer)

        # Model definition with Functional API
        self.model = Model(disc_input, disc_output)

class MyGAN(object):

    def __init__(self):

        ########## GENERATOR ##########

        # We create the optimizer for G
        g_optim = Adam(lr=2e-4, beta_1=0.5)

        self.generator = Generator(latent_shape=(100,))        
        self.generator.model.compile(loss='binary_crossentropy', optimizer=g_optim)

        ########## DISCRIMINATOR ##########

         We create the optimizer for D
        d_optim = Adam(lr=2e-4, beta_1=0.5)

        self.discriminator = Discriminator()
        self.discriminator.model.compile(loss='binary_crossentropy', optimizer=d_optim, metrics=['accuracy'])

        ########## FULL GAN ##########

        # create an Input Layer for the complete GAN
        gan_input = Input(shape=self.latent_shape)

        # link the input of the GAN to the Generator
        G_output = self.generator.model(gan_input)

         For the combined model we will only train the generator => We do not want to backpropagate D while training G
        self.discriminator.model.trainable = False

        # we retrieve the output of the GAN
        gan_output = self.discriminator.model(G_output)

        # we construct a model out of it.
        self.fullgan_model = Model(gan_input, gan_output)
        self.fullgan_model.compile(loss='binary_crossentropy', optimizer=g_optim, metrics=['accuracy'])

    def train_step(self, batch):

        ## Train Generator First ##

        noise = #### Generate some noise with the size: 2*batch (D is trained twice)        
        loss, acc = self.fullgan_model.train_on_batch(noise, np.ones(noise.shape[0]))

        ## Train Discriminator Then ##

        self.discriminator.model.trainable = True

        generated_images = ### Generate samples with G with same size as batch

        d_loss_fake, d_acc_fake = self.discriminator.model.train_on_batch(
                generated_images,
                np.zeros(generated_images.shape[0])
            )
        d_loss_real, d_acc_real = self.discriminator.model.train_on_batch(
            X, 
            np.ones(X.shape[0])
        )

        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        d_acc  = 0.5 * np.add(d_acc_real,  d_acc_fake)

        self.discriminator.model.trainable = False

我的问题很简单,我怎样才能使用Tensorflow重现这样的代码结构?我有一些想法,但我不相信任何这些:

我可以使用 tf.Variable ,然后使用 load 函数在执行期间分配它。问题是:对于每个训练步骤,似乎我需要为每个网络(D和G)执行两个 sess.run()。这显然效率低下......

  • 对于发电机:

    • 1:使用sess.run()调用
    • 生成带有G的数据
    • 2:使用sess.run()调用
    • 在D中加载数据
    • 3:使用其他sess.run()电话
    • 计算损失
    • 4:最后用最后sess.run()
    • 反向传播G.
  • 对于判别者:

    • 1:使用sess.run()调用
    • 生成带有G的数据
    • 2:使用sess.run()调用
    • 在D中加载数据
    • 3:使用sess.run()调用
    • 计算虚假数据的丢失
    • 4:使用sess.run()调用
    • 计算实际数据的损失
    • 5:最后用最后sess.run()
    • 反向传播D.

对我而言,这看起来效率低下,我没有更好的主意。我当然可以使用占位符,它会使用feed_dict“隐藏”加载操作,但不会影响性能(我试过)。

我的目标如下:

  • 直接将G连接到D并且能够避免调用G,只需将G和D直接连接。

  • 能够在从G或数据批处理中获取数据时“切换D”。这将允许我避免来自GPU / CPU =>的数据传输。节省时间

1 个答案:

答案 0 :(得分:0)

您可以使用纯功能方法并使用可变范围重新应用网络来实现所需的设计结构。例如,此代码段设置网络的真实/虚假部分:

with variable_scope.variable_scope(generator_scope) as gen_scope:
  generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
  discriminator_gen_outputs = discriminator_fn(generated_data,
                                               generator_inputs)
with variable_scope.variable_scope(dis_scope, reuse=True):
  discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

考虑使用TensorFlow's TFGAN来避免重新创建GAN基础架构。 These examples演示如何使用TFGAN创建各种GAN(使用,以及使用内置功能。<​​/ p>