我需要构建一个深度神经网络,该网络将两个潜在变量的值作为输入,并生成灰度图像。
我了解这类似于GAN中的生成器网络,但是是否有任何已发布的研究成果或任何专门用于学习此类任务的Python
/ Tensorflow
/ Keras
代码?
答案 0 :(得分:1)
因此,这可能是GAN的一项任务,但不一定如此,这取决于您手头的数据。
使用GAN生成MNIST样本的玩具问题代码:
# define variables
g_input_shape = 100
d_input_shape = (28, 28)
hidden_1_num_units = 500
hidden_2_num_units = 500
g_output_num_units = 784
d_output_num_units = 1
epochs = 25
batch_size = 128
# generator
model_1 = Sequential([
Dense(units=hidden_1_num_units, input_dim=g_input_shape, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
Dense(units=g_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)),
Reshape(d_input_shape),
])
# discriminator
model_2 = Sequential([
InputLayer(input_shape=d_input_shape),
Flatten(),
Dense(units=hidden_1_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
Dense(units=d_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)),
])
from keras_adversarial import AdversarialModel, simple_gan, gan_targets
from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling
# Let us compile our GAN and start the training
gan = simple_gan(model_1, model_2, normal_latent_sampling((100,)))
model = AdversarialModel(base_model=gan,player_params=[model_1.trainable_weights, model_2.trainable_weights])
model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(), player_optimizers=['adam', 'adam'], loss='binary_crossentropy')
history = model.fit(x=train_x, y=gan_targets(train_x.shape[0]), epochs=10, batch_size=batch_size)
# We get a graph like after training for 10 epochs.
plt.plot(history.history['player_0_loss'])
plt.plot(history.history['player_1_loss'])
plt.plot(history.history['loss'])
# After training for 100 epochs, we can now generate images
zsamples = np.random.normal(size=(10, 100))
pred = model_1.predict(zsamples)
for i in range(pred.shape[0]):
plt.imshow(pred[i, :], cmap='gray')
plt.show()
即使在动手做这件事之后,您还是应该从reading开始,开始围绕GAN及其适应性发展的研究。
注意:
当您拥有如此出色的锤子时,很容易将所有任务视为钉子。
但这不一定很漂亮。当您提供有关问题的更多详细信息时,回答您的问题也容易得多。