I created a VAE class and trained it on the MNIST dataset. After training the VAE, I tried to get a reconstructed image from an original image:
y_pred = VAE_model.predict(np.expand_dims(np.expand_dims(X[0],axis=0),axis=-1))
This is what I got.
Then I took the same original image, passed it to the encoder to get the latent representation z, and passed this z to the decoder to get the reconstructed image:
z = VAE_model.get_latent_points(np.expand_dims(np.expand_dims(X[0],axis=0),axis=-1))
z_img = VAE_model.generate(z)
However, this reconstruction is different from the previous image. Why are they different?
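To make the comparison concrete, the two paths can be compared numerically; this is a minimal sketch (it assumes the VAE_model instance defined in the full listing below, and uses np.allclose to test elementwise equality up to floating-point tolerance):

x0 = np.expand_dims(np.expand_dims(X[0],axis=0),axis=-1)
y_full = VAE_model.predict(x0)                                    # path 1: full model in one call
y_two_step = VAE_model.generate(VAE_model.get_latent_points(x0))  # path 2: encoder, then decoder
print(np.allclose(y_full, y_two_step))                            # do the two paths agree?

The full code: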
import numpy as np
from keras.datasets import mnist
from keras.layers import Input, Lambda, Conv2D, Conv2DTranspose, BatchNormalization, Dense, Reshape, Flatten,Activation
from keras.models import Model, Sequential
import keras.backend as K
(X, y_train), (x, y_test) = mnist.load_data()
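# note: X is uint8 in [0, 255]; no rescaling is applied before fit() below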
def sample_z(args):
    mu, log_sigma = args
    n_z = int(mu.shape[-1])
    eps = K.random_normal(shape=(n_z,), mean=0., stddev=1.)
    return mu + K.exp(log_sigma / 2) * eps
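# Note on sample_z: K.random_normal draws fresh noise on every forward pass
# (the reparameterization trick), so any model containing this Lambda layer
# is stochastic at predict() time as well, not only during training.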
def create_encoder_base(input_dim, f1=32):
    model = Sequential(name='encoder_base')
    model.add(Conv2D(f1, kernel_size=(1,1), strides=(1,1), input_shape=input_dim))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    return model

def create_decoder_base(input_dim, f3=1):
    model = Sequential(name='decoder_base')
    model.add(Conv2DTranspose(f3, (1,1), strides=(1,1), input_shape=input_dim))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    return model

def create_encoder_flat(input_dim, n=256):
    model = Sequential(name='encoder_flat')
    model.add(Flatten(input_shape=input_dim))
    model.add(Dense(n))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    flat_size = model.layers[0].output_shape[1]  # size of the flattened feature vector
    return model, flat_size

def create_decoder_reshape(input_dim, output_dim, flat_shape, n=256):
    model = Sequential(name='decoder_reshape')
    model.add(Dense(n, input_shape=input_dim))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Dense(flat_shape))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape(output_dim))
    return model

def create_z_space(input_shape, n_z):
    input_layer = Input((input_shape,), name='z_input')
    Q_z_mean = Dense(n_z, name='z_mean')
    Q_z_log_var = Dense(n_z, name='z_log_var')
    z_mean = Q_z_mean(input_layer)
    z_log_var = Q_z_log_var(input_layer)
    z = Lambda(sample_z)([z_mean, z_log_var])
    model = Model(input_layer, z, name='z_space')
    middle_model = Model(input_layer, [z_mean, z_log_var], name='middle_model')
    return model, middle_model
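# Note: z_space outputs the *sampled* z (via the sample_z Lambda above), while
# middle_model outputs the deterministic pair (z_mean, z_log_var) for the same input.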
class VAE():
    def __init__(self, input_dim, f1=8, n=100, n_z=10):
        ### Creating layers ###
        inputs = Input(input_dim)  # note: unused; input_layer below is used instead
        encoder_base = create_encoder_base(input_dim, f1)
        encoder_output_shape = decoder_input_dim = (encoder_base.layers[-1].output_shape[1:])
        encoder_flat, flat_size = create_encoder_flat(encoder_output_shape, n)
        z_space, z_middle = create_z_space(encoder_flat.layers[-1].output_shape[1], n_z)
        decoder_reshape = create_decoder_reshape((n_z,), encoder_output_shape, flat_size, n)
        decoder_base = create_decoder_base(decoder_input_dim, input_dim[-1])
        Q_z_mean = Dense(n_z, name='z_mean')
        Q_z_log_var = Dense(n_z, name='z_log_var')

        ### Connecting layers ###
        input_layer = Input(input_dim)
        encoder = encoder_base(input_layer)  # encoder
        x = encoder_flat(encoder)
        z = z_space(x)  # variational part; z is the *sampled* latent code
        z_middle(x)  # trying to get the output of the middle layers by connecting them to the previous layers
        x = decoder_reshape(z)  # decoder part
        x = decoder_base(x)
        self.z_mean = z_middle.get_output_at(1)[0]
        self.z_log_var = z_middle.get_output_at(1)[1]
        self.full_model = Model(input_layer, x)
        self.encoder_model = Model(input_layer, z)  # outputs the sampled z, not z_mean

        # Connecting the generator
        z_input = Input((int(z.shape[1]),))
        _x = decoder_reshape(z_input)
        _x = decoder_base(_x)
        self.generator_model = Model(z_input, _x)
        self.n_z = n_z

    def vae_loss(self, y_true, y_pred):
        """Calculate loss = reconstruction loss + KL loss for each sample in the minibatch."""
        # E[log P(X|z)]
        recon = K.sum(K.square(y_pred - y_true), axis=1)  # calculating MSE
        # D_KL(Q(z|X) || P(z|X)); calculated in closed form as both distributions are Gaussian:
        #   KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2)),
        # with z_log_var = log(sigma^2), so K.exp(z_log_var) = sigma^2
        beta = 0.001
        kl = 0.5 * K.sum((K.exp(self.z_log_var) + K.square(self.z_mean) - 1. - self.z_log_var), axis=1)
        return recon + beta * kl

    def summary(self):
        return self.full_model.summary()

    def fit(self, X, epochs=10, batch_size=500):
        self.full_model.compile(optimizer='adam', loss=self.vae_loss)
        self.full_model.fit(X, X, epochs=epochs, batch_size=batch_size)

    def predict(self, X):
        return self.full_model.predict(X)

    def generate(self, z):
        return self.generator_model.predict(z)

    def get_latent_points(self, X):
        return self.encoder_model.predict(X)

    def save_weights(self, file_name):
        return self.full_model.save_weights(file_name)

    def load_weights(self, file_name):
        return self.full_model.load_weights(file_name)

    def get_models(self):
        # returns full_model, encoder_model, generator_model
        return self.full_model, self.encoder_model, self.generator_model

input_dim = (X.shape[1], X.shape[2],1)
VAE_model = VAE(input_dim)
VAE_model.fit(np.expand_dims(X,axis=-1))
y_pred = VAE_model.predict(np.expand_dims(np.expand_dims(X[0],axis=0),axis=-1))
import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(X[0])
plt.imshow(y_pred.reshape(28,28))
z = VAE_model.get_latent_points(np.expand_dims(np.expand_dims(X[0],axis=0),axis=-1))
z_img = VAE_model.generate(z)
plt.imshow(z_img.reshape(28,28))
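To localize where the randomness enters, repeated calls can be compared; this is a sketch whose expected outcomes assume the sampling Lambda draws new noise on each forward pass:

x0 = np.expand_dims(np.expand_dims(X[0],axis=0),axis=-1)

# decoder alone: no sampling layer, so the same z should give the same image twice
print(np.allclose(VAE_model.generate(z), VAE_model.generate(z)))

# encoder: contains the sampling layer, so two calls on the same input should differ
print(np.allclose(VAE_model.get_latent_points(x0), VAE_model.get_latent_points(x0)))

# full model: samples internally, so two predictions on the same input should differ too
print(np.allclose(VAE_model.predict(x0), VAE_model.predict(x0)))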