Keras:编译模型后更新“trainable”属性

时间:2019-04-15 15:56:24

标签: machine-learning keras generative-adversarial-network

我在Keras中有一个条件GAN(CGAN)模型:

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
# Silence every warning for the rest of the run.
warnings.filterwarnings('ignore')

# Directory that GAN.plot_images saves its figures into. exist_ok makes the
# call idempotent and avoids the check-then-create race of a separate
# os.path.exists test.
os.makedirs('images', exist_ok=True)

class GAN(object):
  """Conditional GAN (CGAN) assembled from fully connected Keras layers.

  Three models are built and compiled with binary cross-entropy loss: the
  generator ``G``, the discriminator ``D`` and ``stacked_G_D``, which feeds
  the generator output into the discriminator so that the generator can be
  trained against the discriminator's verdict.
  """

  def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
    """Build and compile the generator, discriminator and stacked model.

    Args:
      width: width of the input images (cast to int).
      height: height of the input images (cast to int).
      channels: number of color channels in the images (cast to int).
      latent_dim: length of the noise vector fed to the generator.
      lr: first positional argument handed to the Adam optimizer.
    """
    self.WIDTH = int(width) # width of input images
    self.HEIGHT = int(height) # height of input images
    self.CHANNELS = int(channels) # n color channels in images
    self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
    self.LATENT_DIM = latent_dim # length of vector used to model latent space (= noise)
    self.N_CLASSES = 10 # total number of possible classes in the data
    # One optimizer instance is shared by all three compiled models.
    self.OPTIMIZER = Adam(lr, 0.5)

    # generator
    self.G = self.generator()
    self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    # discriminator
    self.D = self.discriminator()
    self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
    self.D.trainable = False # prevent stacked D from training; https://github.com/eriklindernoren/Keras-GAN/issues/73

    # stacked generator + discriminator
    # NOTE: this assignment replaces the bound method ``stacked_G_D`` with the
    # model it returns, so the builder can only be called once per instance.
    self.stacked_G_D = self.stacked_G_D()
    self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

  def generator(self):
    """Return the (uncompiled) generator model.

    The model takes ``[noise, label]`` and returns a tensor reshaped to
    ``(WIDTH, HEIGHT, CHANNELS)`` with a tanh output activation.
    """
    noise = Input((self.LATENT_DIM,), name='generator_noise') # allows g to create different outputs
    label = Input((1,), name='generator_conditional', dtype='int32') # allows g to create samples from one class

    # embed label in size of latent dimension
    h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
    label_embedding = Flatten()(h)

    # unified model: the label embedding gates the noise element-wise
    h = multiply([noise, label_embedding])
    h = Dense(256)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(1024)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
    o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

    model = Model(inputs=[noise, label], outputs=[o])
    model.summary()
    return model

  def discriminator(self):
    """Return the (uncompiled) discriminator model.

    The model takes ``[image, label]`` and returns a single sigmoid unit.
    """
    image = Input((self.SHAPE))
    label = Input((1,), dtype='int32')

    # embed the label in the shape of an image (flattened)
    h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
    label_embedding = Flatten()(h)

    # parse out the image
    img = Flatten()(image)

    # unified model: the label embedding gates the flattened image element-wise
    h = multiply([img, label_embedding])
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    o = Dense(1, activation='sigmoid')(h)

    model = Model(inputs=[image, label], outputs=[o])
    model.summary()
    return model

  def stacked_G_D(self):
    """Return the (uncompiled) model that chains ``self.G`` into ``self.D``.

    The same label input is passed to both the generator and the
    discriminator.
    """
    noise = Input((self.LATENT_DIM,)) # noise input
    label = Input((1,)) # conditional input
    img = self.G([noise, label])
    valid = self.D([img, label])
    model = Model(inputs=[noise, label], outputs=[valid])
    model.summary()
    return model

  def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100):
    """Alternate discriminator and generator updates.

    Args:
      X_train: array of training images, indexed along its first axis.
      Y_train: array of class labels aligned with ``X_train``.
      epochs: number of loop iterations; each iteration processes a single
        randomly drawn batch, not a full pass over the data.
      batch: number of samples drawn per iteration.
      save_interval: a progress line is printed and a sample grid is saved
        whenever the iteration index is a multiple of this value.
    """
    for i in range(epochs):

      # train the discriminator on one real batch and one generated batch
      idx = np.random.randint(0, X_train.shape[0], batch)
      imgs, labels = X_train[idx], Y_train[idx]
      noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
      fake_imgs = self.G.predict([noise, labels])
      d_loss_real = self.D.train_on_batch([imgs, labels], np.ones((batch, 1)))
      d_loss_fake = self.D.train_on_batch([fake_imgs, labels], np.zeros((batch, 1)))
      # element-wise mean of the two results returned for D
      d_loss = (np.add(d_loss_real, d_loss_fake)) * 0.5

      # train the generator: target is "real" (ones) for generated samples
      sample_labels = np.random.randint(0, self.N_CLASSES, batch).reshape(batch, 1)
      g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], np.ones((batch, 1)))

      if i % save_interval == 0:
        # {3} is the generator loss; the original reused {2} and therefore
        # printed the discriminator accuracy twice.
        print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
        filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
        self.plot_images(save_to_disk=True, filename=filename)

  def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
    """Generate ``n_images`` samples and draw them in a grid.

    Args:
      save_to_disk: when true, save the figure under ``images/`` and close all
        figures; otherwise show the figure.
      n_images: number of samples to generate and plot.
      filename: file name used when saving; defaults to ``'mnist.png'``.
      rows: number of rows in the grid.
      size_scalar: figure size per grid cell.
      class_arr: optional class labels to condition on; defaults to
        ``0..n_images-1`` wrapped modulo ``N_CLASSES``.
    """
    if not filename: filename = 'mnist.png'
    noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
    classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
    images = self.G.predict([noise, classes])
    # add_subplot needs an integer column count; np.ceil returns a float.
    cols = int(np.ceil(n_images / rows)) # n_cols in grid
    # Single-channel images are drawn as 2-D arrays, others keep their channels.
    image_shape = self.SHAPE[:2] if self.CHANNELS == 1 else self.SHAPE
    fig = plt.figure(figsize=(cols*size_scalar, rows*size_scalar))
    for i in range(n_images):
      ax = fig.add_subplot(rows, cols, i+1)
      ax.imshow(np.reshape(images[i], image_shape))
    fig.subplots_adjust(hspace=0, wspace=0)
    if save_to_disk:
      fig.savefig(os.path.join('images', filename))
      plt.close('all')
    else:
      fig.show()


# Load MNIST; only the training split is used.
(X_train, Y_train), (_, _) = mnist.load_data()

# Map pixel values to [-1, 1].
X_train = X_train.astype(np.float32)
X_train = (X_train - 127.5) / 127.5

# Append a trailing channel axis to the image array.
X_train = X_train[..., np.newaxis]

# Build the models and run the training loop with its default settings.
gan = GAN()
gan.train(X_train, Y_train)

我的目标是定期冻结鉴别器,使其无法学习。(这是一些实验性的工作。)但是,在编译模型之后,我找不到真正更新 gan.D 的 .trainable 属性的方法。我已经尝试过定期手动更改该属性,但鉴别器无论如何都会继续学习。

在编译模型之后,实际上是否可以更新模型的 .trainable 属性?如果可以,希望能给出一个简单的示例来说明如何实现,我将不胜感激!

1 个答案:

答案 0 :(得分:0)

可以。在编译模型之后仍然可以更新模型的 .trainable 属性,但修改后需要重新编译模型才会生效:

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
# Silence every warning for the rest of the run.
warnings.filterwarnings('ignore')

# Directory that GAN.plot_images saves its figures into. exist_ok makes the
# call idempotent and avoids the check-then-create race of a separate
# os.path.exists test.
os.makedirs('images', exist_ok=True)

class GAN(object):
  """Conditional GAN (CGAN) whose discriminator can be frozen periodically.

  Three models are built and compiled with binary cross-entropy loss: the
  generator ``G``, the discriminator ``D`` and ``stacked_G_D``, which feeds
  the generator output into the discriminator. ``train`` can flip
  ``D.trainable`` at a fixed interval and recompile ``D`` afterwards.
  """

  def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
    """Build and compile the generator, discriminator and stacked model.

    Args:
      width: width of the input images (cast to int).
      height: height of the input images (cast to int).
      channels: number of color channels in the images (cast to int).
      latent_dim: length of the noise vector fed to the generator.
      lr: first positional argument handed to the Adam optimizer.
    """
    self.WIDTH = int(width) # width of input images
    self.HEIGHT = int(height) # height of input images
    self.CHANNELS = int(channels) # n color channels in images
    self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
    self.LATENT_DIM = latent_dim # length of vector used to model latent space (= noise)
    self.N_CLASSES = 10 # total number of possible classes in the data
    # One optimizer instance is shared by all three compiled models.
    self.OPTIMIZER = Adam(lr, 0.5)

    # generator
    self.G = self.generator()
    self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    # discriminator
    self.D = self.discriminator()
    # The flag is set BEFORE compiling here, so D is compiled with
    # trainable=False (presumably D starts out frozen until the first toggle
    # in train() - confirm against the Keras version in use).
    self.D.trainable = False # normally this line follows the initial compilation of the D
    self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

    # stacked generator + discriminator
    # NOTE: this assignment replaces the bound method ``stacked_G_D`` with the
    # model it returns, so the builder can only be called once per instance.
    self.stacked_G_D = self.stacked_G_D()
    self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

  def generator(self):
    """Return the (uncompiled) generator model.

    The model takes ``[noise, label]`` and returns a tensor reshaped to
    ``(WIDTH, HEIGHT, CHANNELS)`` with a tanh output activation.
    """
    noise = Input((self.LATENT_DIM,), name='generator_noise') # allows g to create different outputs
    label = Input((1,), name='generator_conditional', dtype='int32') # allows g to create samples from one class

    # embed label in size of latent dimension
    h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
    label_embedding = Flatten()(h)

    # unified model: the label embedding gates the noise element-wise
    h = multiply([noise, label_embedding])
    h = Dense(256)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(1024)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
    o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

    model = Model(inputs=[noise, label], outputs=[o])
    model.summary()
    return model

  def discriminator(self):
    """Return the (uncompiled) discriminator model.

    The model takes ``[image, label]`` and returns a single sigmoid unit.
    """
    image = Input((self.SHAPE))
    label = Input((1,), dtype='int32')

    # embed the label in the shape of an image (flattened)
    h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
    label_embedding = Flatten()(h)

    # parse out the image
    img = Flatten()(image)

    # unified model: the label embedding gates the flattened image element-wise
    h = multiply([img, label_embedding])
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    o = Dense(1, activation='sigmoid')(h)

    model = Model(inputs=[image, label], outputs=[o])
    model.summary()
    return model

  def stacked_G_D(self):
    """Return the (uncompiled) model that chains ``self.G`` into ``self.D``.

    The same label input is passed to both the generator and the
    discriminator.
    """
    noise = Input((self.LATENT_DIM,)) # noise input
    label = Input((1,)) # conditional input
    img = self.G([noise, label])
    valid = self.D([img, label])
    model = Model(inputs=[noise, label], outputs=[valid])
    model.summary()
    return model

  def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100, toggle_D_trainable=None):
    """Alternate discriminator and generator updates, optionally toggling D.

    Args:
      X_train: array of training images, indexed along its first axis.
      Y_train: array of class labels aligned with ``X_train``.
      epochs: number of loop iterations; each iteration processes a single
        randomly drawn batch, not a full pass over the data.
      batch: number of samples drawn per iteration.
      save_interval: a progress line is printed and a sample grid is saved
        whenever the iteration index is a multiple of this value.
      toggle_D_trainable: when truthy, ``self.D.trainable`` is inverted and
        ``self.D`` is recompiled at every positive iteration index that is a
        multiple of this value. ``None`` (or 0) disables toggling.
    """
    for i in range(epochs):

      # train the discriminator on one real batch and one generated batch
      idx = np.random.randint(0, X_train.shape[0], batch)
      imgs, labels = X_train[idx], Y_train[idx]
      noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
      fake_imgs = self.G.predict([noise, labels])
      d_loss_real = self.D.train_on_batch([imgs, labels], np.ones((batch, 1)))
      d_loss_fake = self.D.train_on_batch([fake_imgs, labels], np.zeros((batch, 1)))
      # element-wise mean of the two results returned for D
      d_loss = (np.add(d_loss_real, d_loss_fake)) * 0.5

      # train the generator: target is "real" (ones) for generated samples
      sample_labels = np.random.randint(0, self.N_CLASSES, batch).reshape(batch, 1)
      g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], np.ones((batch, 1)))

      if i % save_interval == 0:
        # {3} is the generator loss; the original reused {2} and therefore
        # printed the discriminator accuracy twice.
        print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
        filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
        self.plot_images(save_to_disk=True, filename=filename)
      if i > 0 and toggle_D_trainable and i % toggle_D_trainable == 0:
        # Flip the flag, then recompile D so the new setting is applied.
        # NOTE(review): stacked_G_D is not recompiled here; whether it keeps
        # treating D as frozen depends on the Keras version - confirm.
        self.D.trainable = not self.D.trainable
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

  def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
    """Generate ``n_images`` samples and draw them in a grid.

    Args:
      save_to_disk: when true, save the figure under ``images/`` and close all
        figures; otherwise show the figure.
      n_images: number of samples to generate and plot.
      filename: file name used when saving; defaults to ``'mnist.png'``.
      rows: number of rows in the grid.
      size_scalar: figure size per grid cell.
      class_arr: optional class labels to condition on; defaults to
        ``0..n_images-1`` wrapped modulo ``N_CLASSES``.
    """
    if not filename: filename = 'mnist.png'
    noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
    classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
    images = self.G.predict([noise, classes])
    # add_subplot needs an integer column count; np.ceil returns a float.
    cols = int(np.ceil(n_images / rows)) # n_cols in grid
    # Single-channel images are drawn as 2-D arrays, others keep their channels.
    image_shape = self.SHAPE[:2] if self.CHANNELS == 1 else self.SHAPE
    fig = plt.figure(figsize=(cols*size_scalar, rows*size_scalar))
    for i in range(n_images):
      ax = fig.add_subplot(rows, cols, i+1)
      ax.imshow(np.reshape(images[i], image_shape))
    fig.subplots_adjust(hspace=0, wspace=0)
    if save_to_disk:
      fig.savefig(os.path.join('images', filename))
      plt.close('all')
    else:
      fig.show()


# Load MNIST; only the training split is used.
(X_train, Y_train), (_, _) = mnist.load_data()

# Map pixel values to [-1, 1].
X_train = X_train.astype(np.float32)
X_train = (X_train - 127.5) / 127.5

# Append a trailing channel axis to the image array.
X_train = X_train[..., np.newaxis]

# Build the models and train, flipping D.trainable every 1000 iterations.
gan = GAN()
gan.train(X_train, Y_train, save_interval=100, toggle_D_trainable=1000)