I am extending the WGAN-GP code base found here into a conditional version: https://github.com/eriklindernoren/Keras-GAN/blob/master/wgan_gp/wgan_gp.py
When I train the model, it does not seem to be conditioned on the labels. The code I use to build the model is included below.
Plotting the model gives:
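I don't know how to interpret the merging arrows on the right. The labels should be concatenated inside the discriminator. I have a feeling this line is messing something up: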
```python
from keras.layers import Input, Lambda
from keras.models import Model

# (Inside the model class: build_generator, build_discriminator,
# random_weighted_average, wasserstein_loss and optimizer are defined elsewhere.)

# The generator takes noise and the target label (states) as input
# and generates the corresponding samples of that label
noise = Input(shape=(self.latent_size,), name="noise")
label = Input(shape=(self.label_size,), name="labels")
real_samples = Input(shape=(self.input_size,), name="real")

self.discriminator = self.build_discriminator()
self.generator = self.build_generator([noise, label])

# First we train the discriminator
self.generator.trainable = False
fake_samples = self.generator([noise, label])

# The same discriminator (shared weights) scores real, fake and
# interpolated samples, each conditioned on the same label input
fake = self.discriminator([fake_samples, label])
valid = self.discriminator([real_samples, label])
interpolated = Lambda(self.random_weighted_average)([real_samples, fake_samples])
valid_interp = self.discriminator([interpolated, label])

self.d_model = Model([real_samples, noise, label],
                     [valid, fake, valid_interp],
                     name="discriminator")

# Time to train the generator
self.discriminator.trainable = False
self.generator.trainable = True

noise_gen = Input(shape=(self.latent_size,), name="noise_gen")
fake_samples = self.generator([noise_gen, label])
valid = self.discriminator([fake_samples, label])

self.g_model = Model([noise_gen, label], valid, name="generator")
self.g_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
```
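The code above calls `self.random_weighted_average` and `self.wasserstein_loss` without showing them. Assuming they behave like the corresponding helpers in the linked Keras-GAN repository, the plain-Python sketch below (not Keras backend code) shows what they compute: a per-sample uniform mix of real and fake samples, and the mean of `y_true * y_pred`. The targets of -1 for real and +1 for fake are an assumption about the training loop, which is not shown. The gradient-penalty term itself needs automatic differentiation and is not sketched here.

```python
import random


def random_weighted_average(real, fake, rng):
    """Per-sample mix x_hat = a * real + (1 - a) * fake with a ~ U(0, 1).

    `real` and `fake` are batches given as lists of equal-length vectors.
    """
    if len(real) != len(fake):
        raise ValueError("real and fake batches must have the same size")
    out = []
    for r, f in zip(real, fake):
        a = rng.random()  # one alpha per sample, shared by all its features
        out.append([a * ri + (1.0 - a) * fi for ri, fi in zip(r, f)])
    return out


def wasserstein_loss(y_true, y_pred):
    """Mean of y_true * y_pred over the batch."""
    return sum(t * p for t, p in zip(y_true, y_pred)) / len(y_true)


rng = random.Random(0)  # fixed seed so the example is reproducible
real_batch = [[1.0, 1.0], [2.0, 2.0]]
fake_batch = [[0.0, 0.0], [0.0, 0.0]]
interp = random_weighted_average(real_batch, fake_batch, rng)

# Assumed targets: -1 for real samples, +1 for fake samples.
critic_real = [0.8, 0.6]
critic_fake = [-0.2, 0.1]
loss = (wasserstein_loss([-1.0, -1.0], critic_real)
        + wasserstein_loss([1.0, 1.0], critic_fake))
print(interp, loss)
```

Each interpolated sample lies on the segment between its real and fake counterpart, which is where the gradient penalty is evaluated.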
I suspect this because I am simply passing `label` along, and I don't know how Keras routes an input through to the other outputs.
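For the routing question, Keras connects inputs by following the tensors passed into each call, so the same `label` input can feed several calls of one shared discriminator. The plain-Python sketch below (not Keras; `one_hot` and `ToyCritic` are hypothetical names) mimics that: one critic object with a single set of weights is applied three times with the same label, and the label is concatenated onto each sample before scoring. In Keras, that concatenation would presumably live inside `build_discriminator` as a `Concatenate()` layer over the sample and label inputs; that method is not shown in the question. Under this reading, a discriminator called three times would likely appear in a model plot as one node with several incoming edges, which may be what the merging arrows are.

```python
def one_hot(index, num_classes):
    """Encode an integer class index as a one-hot list."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


class ToyCritic:
    """Hypothetical one-unit linear critic; the same weights serve every call."""

    def __init__(self, weights, bias=0.0):
        self.weights = weights
        self.bias = bias
        self.calls = 0  # how many times the shared critic has been applied

    def __call__(self, sample, label_vec):
        # Conditioning step: append the label to the sample,
        # as a Concatenate layer would in Keras.
        x = list(sample) + list(label_vec)
        if len(x) != len(self.weights):
            raise ValueError("len(sample) + len(label) must equal len(weights)")
        self.calls += 1
        return sum(w * v for w, v in zip(self.weights, x)) + self.bias


critic = ToyCritic(weights=[1.0, 1.0, 0.5, -0.5])
label = one_hot(1, 2)

real = [0.2, 0.4]
fake = [0.1, 0.3]
# Fixed 0.5 mix here (instead of a random alpha) to keep the example deterministic.
interpolated = [0.5 * r + 0.5 * f for r, f in zip(real, fake)]

# One shared critic, one label, three calls: the same pattern as the
# real, fake and interpolated discriminator calls in the question.
scores = [critic(x, label) for x in (real, fake, interpolated)]

# Same sample, different label: the score changes.
other_score = critic(real, one_hot(0, 2))
print(scores, other_score)
```

Because `other_score` differs from `scores[0]` for the same sample, the critic's output depends on the label, which is the property a conditional critic needs.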