I am modeling a VAE with Keras. My complete notebook is available at https://gist.github.com/v-i-s-h/fdabcb3d85b89ade95758bd014b307f1.
To train the model while logging only the final loss value (the sum of the reconstruction loss and the KL-divergence loss), I have the following snippet:
vae = Model( inputs, outputs, name = "vae" )
# VAE loss = (mse_loss or xent_loss) + kl_loss
# reconLoss = mse( K.flatten(inputs), K.flatten(outputs) )
reconLoss = binary_crossentropy( K.flatten(inputs), K.flatten(outputs) )
reconLoss *= imageSize*imageSize # Because binary_crossentropy divides by N
klLoss = 1 + zLogVar - K.square(zMean) - K.exp(zLogVar)
klLoss = K.sum( klLoss, axis = -1 )
klLoss *= -0.5
vaeLoss = K.mean( reconLoss + klLoss )
vae.add_loss( vaeLoss )
vae.compile( optimizer = 'adam' )
vae.summary()
plot_model( vae, to_file = 'vae_model.png', show_shapes = True )
# Callback
import tensorflow.keras.callbacks as cb
class PlotResults( cb.Callback ):
    def __init__( self, models, data, batch_size, model_name ):
        self.models = models
        self.data = data
        self.batchSize = batch_size
        self.model_name = model_name
        self.epochCount = 0
        super().__init__()

    def on_train_begin( self, logs = {} ):
        self.epochCount = 0
        plot_results( self.models, self.data, batch_size = self.batchSize, epochCount = self.epochCount )

    def on_epoch_end( self, epoch, logs = {} ):
        # print( logs )
        self.epochCount += 1
        plot_results( self.models, self.data, batch_size = self.batchSize, epochCount = self.epochCount )
cbPlotResults = PlotResults( models, data, batchSize, "." )
trainLog = vae.fit( xTrain,
                    epochs = epochs,
                    batch_size = batchSize,
                    validation_data = (xTest, None),
                    callbacks = [cbPlotResults] )
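For reference, the klLoss lines above compute the standard closed-form KL divergence between the approximate posterior $\mathcal{N}(\mu, \sigma^2)$ (with $\mu$ = zMean and $\log\sigma^2$ = zLogVar) and the unit-Gaussian prior, summed over the latent dimensions:

$D_{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big) = -\tfrac{1}{2}\sum_{j}\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$

and reconLoss *= imageSize*imageSize rescales binary_crossentropy from a per-pixel mean back to a per-image sum, so the two terms end up on comparable scales.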
With this, the model seems to train properly (see the plots in the linked notebook: https://gist.github.com/v-i-s-h/fdabcb3d85b89ade95758bd014b307f1) and everything works as expected.
Now I would also like to monitor the individual reconstruction loss and the KL-divergence loss during training. For that, I modified the code to this:
vae2 = Model( inputs, outputs, name = "vae2" )
# ====================== CHANGE ==================================
def fn_reconLoss( x, x_hat ):
    # reconLoss = mse( K.flatten(x), K.flatten(x_hat) )
    reconLoss = binary_crossentropy( K.flatten(x), K.flatten(x_hat) )
    reconLoss *= imageSize*imageSize    # because binary_crossentropy divides by N
    return reconLoss

def fn_klLoss( x, x_hat ):
    klLoss = 1 + zLogVar - K.square(zMean) - K.exp(zLogVar)
    klLoss = K.sum( klLoss, axis = -1 )
    klLoss *= -0.5
    return klLoss

def fn_vaeloss( x, x_hat ):
    return K.mean( fn_reconLoss(x, x_hat) + fn_klLoss(x, x_hat) )
# ====================================================================
# vae2.add_loss( fn_vaeloss )
vae2.compile( optimizer = 'adam', loss=fn_vaeloss, metrics = [fn_reconLoss,fn_klLoss] )
vae2.summary()
plot_model( vae2, to_file = 'vae2_model.png', show_shapes = True )
# Callback
import tensorflow.keras.callbacks as cb
class PlotResults( cb.Callback ):
    def __init__( self, models, data, batch_size, model_name ):
        self.models = models
        self.data = data
        self.batchSize = batch_size
        self.model_name = model_name
        self.epochCount = 0
        super().__init__()

    def on_train_begin( self, logs = {} ):
        self.epochCount = 0
        plot_results( self.models, self.data, batch_size = self.batchSize, epochCount = self.epochCount )

    def on_epoch_end( self, epoch, logs = {} ):
        # print( logs )
        self.epochCount += 1
        plot_results( self.models, self.data, batch_size = self.batchSize, epochCount = self.epochCount )
cbPlotResults = PlotResults( models, data, batchSize, "." )
trainLog = vae2.fit( xTrain,
                     epochs = epochs,
                     batch_size = batchSize,
                     validation_data = (xTest, xTest),
                     callbacks = [cbPlotResults] )
With this, the model learns incorrectly: even though the loss appears to decrease, the reconstructions are completely useless (see the last block of https://gist.github.com/v-i-s-h/fdabcb3d85b89ade95758bd014b307f1).
I cannot figure out what I am getting wrong when defining the loss functions in the second block. Is this the correct way to do it?
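For completeness, one alternative I can think of (just a sketch, not what the notebook currently does) would be to keep the add_loss formulation of the first model and expose the two terms through Model.add_metric, which I believe tf.keras provides for tensors built inside the graph; vae3 below is only a hypothetical name:

vae3 = Model( inputs, outputs, name = "vae3" )
vae3.add_loss( vaeLoss )
# Log the per-batch mean of each term next to the total loss
vae3.add_metric( reconLoss, name = 'recon_loss', aggregation = 'mean' )
vae3.add_metric( klLoss, name = 'kl_loss', aggregation = 'mean' )
vae3.compile( optimizer = 'adam' )
# vae3.fit( xTrain, epochs = epochs, batch_size = batchSize, validation_data = (xTest, None) )
# should then report loss, recon_loss and kl_loss in the training logs

Would that be the better route here, rather than compile( loss = ..., metrics = [...] )?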