当我在最后一个使用权重裁剪的鉴别器层中使用线性激活或不激活时,鉴别器精度变为1,生成器变为0。如果我删除权重剪辑,则生成器精度将变为1,而鉴别器将变为0,大约进行300次迭代。但是,当我使用S型激活作为鉴别器的最后一层时,会产生削波发生器精度变为1,而没有削波,则发生器的损耗会卡住,而精度会达到0.5左右。
注意-在所有情况下,都会产生结果并显示所有警告:tensorflow:可训练砝码与收集的可训练砝码之间的差异,您是否设置了model.trainable
而不在之后调用model.compile
?
代码在这里给出,请不要介意在任何地方复制和粘贴缩进-
class WGAN():
def __init__(self,
input_dim,
disc_filter,
disc_kernel,
disc_strides,
disc_dropout,
disc_lr,
gen_filter,
gen_kernel,
gen_strides,
gen_upsample,
gen_lr,
z_dim,
batch_size):
self.input_dim = input_dim
self.disc_filter = disc_filter
self.disc_kernel = disc_kernel
self.disc_strides = disc_strides
self.disc_dropout = disc_dropout
self.disc_lr = disc_lr
self.gen_filter = gen_filter
self.gen_kernel = gen_kernel
self.gen_strides = gen_strides
self.gen_upsample = gen_upsample
self.gen_lr = gen_lr
self.z_dim = z_dim
self.batch_size = batch_size
self.weight_init = RandomNormal(mean=0., stddev=0.02)
self.d_losses = []
self.g_losses = []
self.epoch = 0
self.Discriminator()
self.Generator()
self.full_model()
def wasserstein(self, y_true, y_pred):
return -K.mean(y_true * y_pred)
def Discriminator(self):
disc_input = Input(shape=self.input_dim, name='discriminator_input')
x = disc_input
for i in range(len(self.disc_filter)):
x = Conv2D(filters=self.disc_filter[i], kernel_size=self.disc_kernel[i], strides=self.disc_strides[i], padding='same', name='disc_'+str(i))(x)
x = LeakyReLU()(x)
x = Dropout(self.disc_dropout)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
disc_output = Dense(1, activation='sigmoid', kernel_initializer = self.weight_init)(x)
self.discriminator = Model(disc_input, disc_output)
def Generator(self):
gen_input = Input(shape=(self.z_dim,), name='generator_input')
x = gen_input
x = Dense(7*7*self.batch_size, kernel_initializer = self.weight_init)(x)
x = LeakyReLU()(x)
x = BatchNormalization()(x)
x = Reshape(target_shape=(7,7,self.batch_size))(x)
for i in range(len(self.gen_filter)):
if self.gen_upsample[i]==2:
x = UpSampling2D(size=self.gen_upsample[i], name='upsample_'+str(i/2))(x)
x = Conv2D(filters=self.gen_filter[i], kernel_size=self.gen_kernel[i], strides=self.gen_strides[i], padding='same', name='gen_'+str(i))(x)
else:
x = Conv2DTranspose(filters=self.gen_filter[i], kernel_size=self.gen_kernel[i], strides=self.gen_strides[i], padding='same', name='gen_'+str(i))(x)
if i<len(self.gen_filter)-1:
x = BatchNormalization()(x)
x = LeakyReLU()(x)
else:
x = Activation("tanh")(x)
gen_output = x
self.generator = Model(gen_input, gen_output)
def set_trainable(self, model, val):
model.trainable=val
for l in model.layers:
l.trainable=val
def full_model(self):
### COMPILE DISCRIMINATOR
self.discriminator.compile(optimizer= Adam(self.disc_lr), loss = self.wasserstein, metrics=['accuracy'])
### COMPILE THE FULL GAN
self.set_trainable(self.discriminator, False)
self.discriminator.compile(optimizer= Adam(self.disc_lr), loss = self.wasserstein, metrics=['accuracy'])
model_input = Input(shape=(self.z_dim,), name='model_input')
model_output = self.discriminator(self.generator(model_input))
self.model = Model(model_input, model_output)
self.model.compile(optimizer= Adam(self.disc_lr), loss = self.wasserstein, metrics=['accuracy'])
self.set_trainable(self.discriminator, True)
def train_generator(self, batch_size):
valid = np.ones((batch_size,1))
noise = np.random.normal(0, 1, (batch_size, self.z_dim))
return self.model.train_on_batch(noise, valid)
def train_discriminator(self, x_train, batch_size, using_generator):
valid = np.ones((batch_size,1))
fake = np.zeros((batch_size,1))
if using_generator:
true_imgs = next(x_train)[0]
if true_imgs.shape[0] != batch_size:
true_imgs = next(x_train)[0]
else:
idx = np.random.randint(0, x_train.shape[0], batch_size)
true_imgs = x_train[idx]
noise = np.random.normal(0, 1, (batch_size, self.z_dim))
gen_imgs = self.generator.predict(noise)
d_loss_real, d_acc_real = self.discriminator.train_on_batch(true_imgs, valid)
d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * (d_loss_real + d_loss_fake)
d_acc = 0.5 * (d_acc_real + d_acc_fake)
for l in self.discriminator.layers:
weights = l.get_weights()
weights = [np.clip(w, -0.01, 0.01) for w in weights]
l.set_weights(weights)
return [d_loss, d_loss_real, d_loss_fake, d_acc, d_acc_real, d_acc_fake]
def train(self, x_train, batch_size, epochs, print_every_n_batches = 50, using_generator = False):
for epoch in range(self.epoch, self.epoch + epochs):
d = self.train_discriminator(x_train, batch_size, using_generator)
g = self.train_generator(batch_size)
if self.epoch % print_every_n_batches == 0:
print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" % (epoch, d[0], d[1], d[2], d[3], d[4], d[5], g[0], g[1]))
self.d_losses.append(d)
self.g_losses.append(g)
self.epoch+=1