使用生成对抗神经网络从潜在变量生成图像

时间:2018-12-21 18:32:03

标签: python tensorflow deep-learning deconvolution generative-adversarial-network

我需要构建一个深度神经网络,该网络将两个潜在变量的值作为输入,并生成灰度图像。

我了解这类似于GAN中的生成器网络,但是是否有任何已发布的研究成果或任何专门用于学习此类任务的Python / Tensorflow / Keras代码?

1 个答案:

答案 0 :(得分:1)

因此,这可能是GAN的一项任务,但不一定如此,这取决于您手头的数据。

  

使用GAN生成MNIST样本的玩具问题代码:

# define variables
g_input_shape = 100 
d_input_shape = (28, 28) 
hidden_1_num_units = 500 
hidden_2_num_units = 500 
g_output_num_units = 784 
d_output_num_units = 1 
epochs = 25 
batch_size = 128

# generator
model_1 = Sequential([
    Dense(units=hidden_1_num_units, input_dim=g_input_shape, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
    Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),   
    Dense(units=g_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)),
    Reshape(d_input_shape),
])

# discriminator
model_2 = Sequential([
    InputLayer(input_shape=d_input_shape),
    Flatten(),   
    Dense(units=hidden_1_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
    Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),    
    Dense(units=d_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)),
])


from keras_adversarial import AdversarialModel, simple_gan, gan_targets
from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling

# Let us compile our GAN and start the training
gan = simple_gan(model_1, model_2, normal_latent_sampling((100,)))
model = AdversarialModel(base_model=gan,player_params=[model_1.trainable_weights, model_2.trainable_weights])
model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(), player_optimizers=['adam', 'adam'], loss='binary_crossentropy')

history = model.fit(x=train_x, y=gan_targets(train_x.shape[0]), epochs=10, batch_size=batch_size)

# We get a graph like after training for 10 epochs.
plt.plot(history.history['player_0_loss'])
plt.plot(history.history['player_1_loss'])
plt.plot(history.history['loss'])

# After training for 100 epochs, we can now generate images
zsamples = np.random.normal(size=(10, 100))
pred = model_1.predict(zsamples)
for i in range(pred.shape[0]):
    plt.imshow(pred[i, :], cmap='gray')
plt.show()

即使在动手做这件事之后,您还是应该从reading开始,开始围绕GAN及其适应性发展的研究。

  

注意:

当您拥有如此出色的锤子时,很容易将所有任务视为钉子。

但这不一定很漂亮。当您提供有关问题的更多详细信息时,回答您的问题也容易得多。

  1. 潜在变量看起来如何?
  2. 它们是否与灰度图像配对?
  3. 您有多少数据?规格是什么?