将类信息添加到keras网络

时间:2018-06-18 11:56:42

标签: python keras conv-neural-network loss-function generative-adversarial-network

我试图找出如何使用Generative Adversarial Networks的数据集标签信息。我正在尝试使用can be found here的条件GAN的以下实现。我的数据集包含两个不同的图像域(真实对象和草图),具有公共类信息(椅子,树,橙等)。我选择了这种实现,它只将两个不同的域视为对应的不同“类”(列车样本X对应于真实图像,而目标样本y对应于草图图像)。

有没有办法修改我的代码并考虑我整个架构中的类信息(主席,树等)?我希望我的鉴别器能够预测我生成的图像来自生成器是否属于特定类,而不仅仅是它们是否真实。实际上,在当前架构中,系统学会在所有情况下创建类似的草图。

更新:鉴别器返回一个大小为1x7x7的张量,然后在计算损失之前,y_truey_pred都会通过展平层传递:

def discriminator_loss(y_true, y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.concatenate([K.ones_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])),K.zeros_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])) ]) ), axis=-1)

和鉴别器在发生器上的损失功能:

def discriminator_on_generator_loss(y_true,y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)

此外,我修改了输出1层的鉴别器模型:

model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
#model.add(Activation('sigmoid'))

现在鉴别器输出1层。如何相应修改上述损失函数?我应该有7而不是1,n_classes = 6 +一类来预测实对和假对吗?

2 个答案:

答案 0 :(得分:7)

建议的解决方案

重用repository you shared中的代码,以下是一些建议的修改,以沿生成器和鉴别器训练分类器(不影响其体系结构和其他损失):

from keras import backend as K
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D

def lenet_classifier_model(nb_classes):
    # Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
    # Replace with your favorite classifier...
    model = Sequential()
    model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(180, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax', init='he_normal'))

def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)

    merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
    discriminator.trainable = False
    x_discriminator = discriminator(merged)

    classifier.trainable = False
    x_classifier = classifier(x_generator)

    model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])

    return model


def train(BATCH_SIZE):
    (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
    discriminator = discriminator_model()
    generator = generator_model()
    classifier = lenet_classifier_model(6)
    generator.summary()
    discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
        generator, discriminator, classifier)
    d_optim = Adagrad(lr=0.005)
    g_optim = Adagrad(lr=0.005)
    generator.compile(loss='mse', optimizer="rmsprop")
    discriminator_and_classifier_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
    classifier.trainable = True
    classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
        for index in range(int(X_train.shape[0] / BATCH_SIZE)):
            image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here

            generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")

            # Training D:
            real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
                                        axis=1)
            fake_pairs = np.concatenate(
                (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            y = np.zeros((20, 1, 64, 64))  # [1] * BATCH_SIZE + [0] * BATCH_SIZE
            d_loss = discriminator.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            discriminator.trainable = False

            # Training C:
            c_loss = classifier.train_on_batch(image_batch, label_batch)
            print("batch %d c_loss : %f" % (index, c_loss))
            classifier.trainable = False

            # Train G:
            g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], 
                [image_batch, np.ones((10, 1, 64, 64)), label_batch])
            discriminator.trainable = True
            classifier.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)

理论细节

我相信对于条件GAN的工作方式以及这种方案中的歧视角色有什么误解。

鉴别者的作用

在GAN训练的最小-最大游戏[4]中,鉴别器D与生成器G(您实际上关心的网络)对战,因此在D下的审查,G在输出实际结果方面变得更好。

为此,D受过训练,可以区分来自G的样本中的真实样本;而G受过训练,可以根据目标分布生成逼真的结果/结果来欺骗D

  

注意:在条件GAN的情况下,即GAN将输入样本从一个域A(例如,真实图片)映射到另一个域B   (例如,草图),D通常被馈送成对堆叠的样本对   并必须区分“真实”对(来自A的输入样本+   B中对应的目标样本)和“假”对(输入样本)   来自A和来自G的相应输出) [1、2]

针对D训练条件生成器(与仅训练G,仅损失L1 / L2的情况(例如DAE)相反)提高了G的采样能力,迫使其输出清晰逼真的结果,而不是尝试平均分配。

即使辨别器可以具有多个子网来覆盖其他任务(请参阅以下段落),D也应保留至少一个子网/输出来覆盖其主要任务:告诉真实样本分开生成。要求D与其他语义信息(例如类)一起回归可能会干扰此主要目的。

  

注意:D输出通常不是简单的标量/布尔值。通常有一个鉴别器(例如PatchGAN [1,2])返回一个矩阵   概率,评估其输入如何制作出逼真的补丁   是。


条件GAN

传统GAN以无监督的方式进行训练,以从随机噪声向量作为输入生成真实数据(例如图像)。 [4]

如前所述,条件GAN进一步输入了条件 。它们沿着噪声矢量(而不是噪声矢量)从域A输入样本,并从域B返回相应的样本。 A可以是完全不同的形式,例如B = sketch imageA = discrete labelB = volumetric dataA = RGB image等。[3]

此类GAN也可以通过多个输入来调节,例如A = real image + discrete label,而B = sketch image。引入这种方法的著名著作是 InfoGAN [5]。它介绍了如何使用更高级的鉴别器来对多个连续或离散输入(例如A = digit class + writing typeB = handwritten digit image)上的GAN进行条件处理,该鉴别器具有第二项任务来强制G来最大化条件输入和相应输出之间的相互信息。


最大化cGAN的相互信息

InfoGAN鉴别器具有2个负责人/子网,可以覆盖其2个任务[5]:

  • 一个人的头D1进行传统的真实/生成的区分-G必须最小化此结果,即必须愚弄D1以使其无法分辨真实形式生成的数据;
  • 另一个头D2(也称为Q网络)试图回归输入的A信息-G必须最大化此结果,即它必须输出数据可以“显示”所请求的语义信息(参见G条件输入与其输出之间的互信息最大化)。

例如,您可以在这里找到Keras实现:https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan

通过使用提供的标签并最大化这些输入和G输出之间的相互信息,许多工作正在使用类似的方案来改进对GAN生成的控制[6,7]。基本思想始终是相同的:

  • 在给定域G的某些输入的情况下,训练B来生成域A的元素;
  • 训练D来区分“真实” /“假”结果-G必须将其最小化;
  • 训练Q(例如,一个分类器;可以与D共享图层)以估计来自A个样本的原始B输入-G必须最大化这一点。)

总结

就您而言,您似乎具有以下培训数据:

  • 真实图像Ia
  • 相应的素描图像Ib
  • 相应的类标签c

您想训练生成器G,以便在给定图像Ia及其类标签c的情况下,它会输出适当的草图图像Ib'

总而言之,这是您拥有的很多信息,您可以监督条件图像和条件标签上的培训... 受到前述方法[1、2、5、6、7]的启发,这是一种使用所有这些信息来训练条件G的可能方法:

网络G
  • 输入:Ia + c
  • 输出:Ib'
  • 架构:由您决定(例如U-Net,ResNet等)
  • 损失:Ib'Ib之间的L1 / L2损失,-D损失,Q损失
网络D
  • 输入:Ia + Ib(真实对),Ia + Ib'(伪对)
  • 输出:“伪造”标量/矩阵
  • 架构:由您自己决定(例如PatchGAN)
  • 损失:“假性”估计的交叉熵
网络Q
  • 输入:Ib(用于训练Q的真实样本),Ib'(通过G反向传播时的伪样本)
  • 输出:c'(估计类别)
  • 架构:由您自己决定(例如LeNet,ResNet,VGG等)
  • 损失:cc'之间的交叉熵
培训阶段:
  1. 在一批真实对D + Ia上训练Ib,然后在一批假对Ia + Ib'上训练;
  2. 在一批真实样本Q上训练Ib
  3. 固定DQ的权重;
  4. 训练G,将其生成的输出Ib'传递到DQ,以通过它们向后传播。
  

注意:这是一个非常粗糙的体系结构描述。我建议您阅读文献([1、5、6、7]作为一个好的开始)以获取   更多细节,也许还有更精致的解决方案。


参考文献

    伊索拉(Isola),菲利普(Phillip)等人。 “使用条件对抗网络进行图像到图像的翻译。” arXiv预印本(2017)。 http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf 朱丽君等。 “使用周期一致的对抗网络进行不成对的图像到图像的翻译。” arXiv预印本arXiv:1703.10593(2017)。 http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
  1. Mirza,Mehdi和Simon Osindero。 “有条件的对抗网络。” arXiv预印本arXiv:1411.1784(2014)。 https://arxiv.org/pdf/1411.1784
  2. Goodfellow,Ian等人。 “生成对抗网络。”神经信息处理系统的进步。 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
  3. 陈曦等。 “ Infogan:通过最大化生成对抗网络的信息来进行可解释的表示学习。”神经信息处理系统的进展。 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
  4. 李,敏赫和俊熙。 “可控的生成对抗网络。” arXiv预印本arXiv:1708.00598(2017)。 https://arxiv.org/pdf/1708.00598.pdf
  5. Odena,Augustus,Christopher Olah和Jonathon Shlens。 “使用辅助分类器gans进行条件图像合成。” arXiv预印本arXiv:1610.09585(2016)。 http://proceedings.mlr.press/v70/odena17a/odena17a.pdf

答案 1 :(得分:3)

您应该修改区分模型,使其具有两个输出或具有“ n_classes + 1”输出。

警告:我在您的鉴别器的定义中看不到它输出'true / false',我看到它输出图像...

它应该包含GlobalMaxPooling2DGlobalAveragePooling2D的某个地方。
最后是一层或多层Dense用于分类。

如果说出是非题,则最后一个Dense应该有1个单位。
否则为n_classes + 1个单位。

因此,鉴别符的结尾应类似于

...GlobalMaxPooling2D()...
...Dense(someHidden,...)...
...Dense(n_classes+1,...)...

鉴别器现在将输出n_classes加上“ true / fake”符号(您将无法在其中使用“ categorical”)或什至是“ fake class”(然后将其他类归零使用分类)

您生成的草图应与将伪类与其他类串联在一起的目标一起传递给鉴别器。

选项1-使用“真/假”符号。 (不要使用“ categorical_crossentropy”)

#true sketches into discriminator:
fakeClass = np.zeros((total_samples,))
sketchClass = originalClasses

targetClassTrue = np.concatenate([fakeClass,sketchClass], axis=-1)

#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches))
sketchClass = originalClasses

targetClassFake = np.concatenate([fakeClass,sketchClass], axis=-1)

选项2-使用“假类”(可以使用“ categorical_crossentropy”):

#true sketches into discriminator:
fakeClass = np.zeros((total_samples,))
sketchClass = originalClasses

targetClassTrue = np.concatenate([fakeClass,sketchClass], axis=-1)

#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches))
sketchClass = np.zeros((total_fake_sketches, n_classes))

targetClassFake = np.concatenate([fakeClass,sketchClass], axis=-1)

现在将所有内容连接到一个目标数组中(与输入草图有关)

更新的训练方法

对于这种训练方法,您的损失函数应为以下之一:

  • discriminator.compile(loss='binary_crossentropy', optimizer=....)
  • discriminator.compile(loss='categorical_crossentropy', optimizer=...)

代码:

for epoch in range(100):
    print("Epoch is", epoch)
    print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))

    for index in range(int(X_train.shape[0]/BATCH_SIZE)):

        #names:
            #images -> initial images, not changed    
            #sketches -> generated + true sketches    
            #classes -> your classification for the images    
            #isGenerated -> the output of your discriminator telling whether the passed sketches are fake

        batchSlice = slice(index*BATCH_SIZE,(index+1)*BATCH_SIZE)
        trueImages = X_train[batchSlice]

        trueSketches = Y_train[batchSlice] 
        trueClasses = originalClasses[batchSlice]
        trueIsGenerated = np.zeros((len(trueImages),)) #discriminator telling whether the sketch is fake or true (generated images = 1)
        trueEndTargets = np.concatenate([trueIsGenerated,trueClasses],axis=1)

        fakeSketches = generator.predict(trueImages)
        fakeClasses = originalClasses[batchSlize]             #if option 1 -> telling class + isGenerated - use "binary_crossentropy"
        fakeClasses = np.zeros((len(fakeSketches),n_classes)) #if option 2 -> telling if generated is an individual class - use "categorical_crossentropy"    
        fakeIsGenerated = np.ones((len(fakeSketches),))
        fakeEndTargets = np.concatenate([fakeIsGenerated, fakeClasses], axis=1)

        allSketches = np.concatenate([trueSketches,fakeSketches],axis=0)            
        allEndTargets = np.concatenate([trueEndTargets,fakeEndTargets],axis=0)

        d_loss = discriminator.train_on_batch(allSketches, allEndTargets)

        pred_temp = discriminator.predict(allSketches)
        #print(np.shape(pred_temp))
        print("batch %d d_loss : %f" % (index, d_loss))

        ##WARNING## In previous keras versions, "trainable" only takes effect if you compile the models. 
            #you should have the "discriminator" and the "discriminator_on_generator" with these set at the creation of the models and never change it again   

        discriminator.trainable = False
        g_loss = discriminator_on_generator.train_on_batch(trueImages, trueEndTargets)
        discriminator.trainable = True


        print("batch %d g_loss : %f" % (index, g_loss[1]))
        if index % 20 == 0:
            generator.save_weights('generator', True)
            discriminator.save_weights('discriminator', True)

正确编译模型

创建“ discriminator”和“ discriminator_on_generator”时:

discriminator.trainable = True
for l in discriminator.layers:
    l.trainable = True


discriminator.compile(.....)

for l in discriminator_on_generator.layer[firstDiscriminatorLayer:]:
    l.trainable = False

discriminator_on_generator.compile(....)