我有一个VAE tf.keras模型,它由编码器和解码器模型组成。我想在训练VAE(使用add_loss()
)时使用不同的自定义术语(使用VAE.fit()
)对编码器和解码器进行规范化。
使用伪代码:
encoder = model(input, encoding)
decoder = model(encoding, output)
encoder.add_loss(custom_loss_1)
decoder.add_loss(custom_loss_2)
vae = model(input, output)
vae.compile(optimizer, standard_keras_loss_function)
vae.fit(args)
请注意,我无法将自定义损耗添加到包装模型(vae),因为我希望仅在它们各自的子模型中对它们进行反向传播(编码器的损耗不应影响解码器的梯度,反之亦然)。我也不能单独安装编码器和解码器,VAE受益于联合培训编码器和解码器。
我面临两个挑战:
尽管我在训练过程中没有发现错误,但是在拟合条形包装器模型时,如果考虑编码器和解码器的损失,我没有轻松的方法。我是否需要编译编码器和解码器,以考虑到增加的自定义损失?
我不希望将“ standard_keras_loss”(二进制交叉熵)反向传播到解码器之外,它只会影响解码器的梯度。但是,vae.compile()
方法要求我为VAE模型增加至少一个损失。因此,防止这种损失向后传播到编码器的唯一方法是尝试在编码器之后添加一个stop_gradient
lambda层,但这将使编码器与解码器断开连接(实际上,无论我在哪里,都将出现断开连接的图形错误尝试添加stop_gradient
层。)
是否有一种TF.keras
方式无需编写自定义训练循环即可?我正在尝试保持干净的代码。