回归任务的贝叶斯卷积神经网络(变分elbo损失)的TF实现

时间:2020-05-21 14:52:37

标签: python tensorflow keras bayesian-networks

我试图为张量流中的回归任务实现贝叶斯卷积神经网络(变分推断ELBO损失)。根据TensorFlow概率作者(https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/bayesian_neural_network.py)编写的此示例,我假设“ Keras API可以自动将Kullback-Leibler散度(包含在模型的各个翻转层中)自动添加到negloglik损失中计算(负)证据下界损失(ELBO)”。

尽管以下实现在训练过程中未返回任何错误,但是模型的平均绝对错误(mae)是原始cnn的mae值的两倍。有人可以告诉我这个实现有什么问题吗?

代码如下,

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from tensorflow.keras import Model
from tensorflow.keras.layers import Input

kl_divergence_function = (lambda q, p, _: tfd.kl_divergence(q, p) / tf.cast(N_training_examples, dtype=tf.float32))

def createModel(patchSize):
    input_tensor = Input(shape=(patchSize[0], patchSize[1],patchSize[2]))
    x = tfp.layers.Convolution2DFlipout(8, (3, 3),  kernel_divergence_fn=kl_divergence_function, strides = 4, activation = 'relu')(input_tensor)
    x = tfp.layers.Convolution2DFlipout(16, (3, 3),  kernel_divergence_fn=kl_divergence_function, strides = 4, activation = 'relu')(x)
    x = tf.keras.layers.Flatten()(x)
    x = tfp.layers.DenseFlipout(512, kernel_divergence_fn=kl_divergence_function,)(x)
    x = tfp.layers.DenseFlipout(256, kernel_divergence_fn=kl_divergence_function,)(x)
    x = tfp.layers.DenseFlipout(128, kernel_divergence_fn=kl_divergence_function,)(x)
    x = tfp.layers.DenseFlipout(units=2,activation='linear', kernel_divergence_fn=kl_divergence_function)(x)
    output = tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t[..., :1],
                           scale=1e-3 + tf.math.softplus(0.01 * t[...,1:])))(x)
    cnn = Model(input_tensor, output, name=sModelName)
    return cnn

input_size = [121, 145, 6] 
model = createModel(input_size)

adm = tf.keras.optimizers.Adam(0.0001)
negloglik = lambda y, p_y: -p_y.log_prob(y)

model.compile(loss=negloglik, optimizer=adm )
model.summary()

model.fit(x_train,y_train, epochs=20,steps_per_epoch=int(N_training_examples/batch_size))

系统信息

tensorflow 1.15.2
TensorFlow Probability 0.7.0
google colab

任何帮助将不胜感激。谢谢!

0 个答案:

没有答案