Bayesian neural network using TensorFlow Probability

Time: 2020-08-03 11:31:19

Tags: deep-learning tensorflow-probability bayesian-deep-learning

I am trying to learn a Bayesian neural network using TensorFlow Probability. I want to learn a response y_t from input features x_t, i.e.

y_t = f(x_t) + eps

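where f(x_t) is the output of the neural network and eps models the uncertainty. As a first step, I assume that all network weights have Gaussian priors with zero mean and unit variance, and that eps is zero-mean noise with unit variance. This can be implemented by following the example at https://colab.research.google.com/github/tensorchiefs/dl_book/blob/master/chapter_08/nb_ch08_03.ipynb: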
import tensorflow_probability as tfp
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

tfd = tfp.distributions

# x is the training input array (defined elsewhere, 10 features per row).
# Dividing each KL term by the number of training examples puts it on the
# same per-example scale as the NLL loss.
kernel_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / (x.shape[0] * 1.0)
bias_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / (x.shape[0] * 1.0)

def NLL(y, distr):
  # Negative log-likelihood of y under the predicted distribution.
  return -distr.log_prob(y)

def normal_sp(params):
  # Predicted mean comes from the network; the noise scale is fixed at 1.
  return tfd.Normal(loc=params[:, 0:1], scale=1.0)

inputs = Input(shape=(10,))

hidden = tfp.layers.DenseFlipout(10, bias_posterior_fn=tfp.layers.util.default_mean_field_normal_fn(),
                           bias_prior_fn=tfp.layers.default_multivariate_normal_fn,
                           kernel_divergence_fn=kernel_divergence_fn,
                           bias_divergence_fn=bias_divergence_fn, activation="relu")(inputs)
hidden = tfp.layers.DenseFlipout(5, bias_posterior_fn=tfp.layers.util.default_mean_field_normal_fn(),
                           bias_prior_fn=tfp.layers.default_multivariate_normal_fn,
                           kernel_divergence_fn=kernel_divergence_fn,
                           bias_divergence_fn=bias_divergence_fn, activation="relu")(hidden)
params = tfp.layers.DenseFlipout(1, bias_posterior_fn=tfp.layers.util.default_mean_field_normal_fn(),
                           bias_prior_fn=tfp.layers.default_multivariate_normal_fn,
                           kernel_divergence_fn=kernel_divergence_fn,
                           bias_divergence_fn=bias_divergence_fn)(hidden)
# Wrap the single output unit in a Normal distribution.
dist = tfp.layers.DistributionLambda(normal_sp)(params)

model_vi = Model(inputs=inputs, outputs=dist)
model_vi.compile(Adam(learning_rate=0.0002), loss=NLL)
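
For reference, the two divergence lambdas divide each layer's KL term by the number of training examples. Keras averages the NLL over the examples in a batch, so this scaling keeps the KL term on the same per-example footing. Below is a minimal, framework-free sketch of that scaling, assuming independent Gaussian posteriors and a standard normal prior. The function names `kl_normal` and `scaled_kl` are hypothetical and not part of TensorFlow Probability.

```python
import math

def kl_normal(mu_q, sigma_q, mu_p=0.0, sigma_p=1.0):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ) for univariate Gaussians."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)

def scaled_kl(posterior_params, n_train):
    """Sum the per-weight KL terms and divide by the training-set size.

    posterior_params: iterable of (mean, std) pairs, one per weight.
    n_train: number of training examples (plays the role of x.shape[0]).
    """
    total = sum(kl_normal(mu, sigma) for mu, sigma in posterior_params)
    return total / n_train

# Two weights whose posteriors are N(1, 1), with 4 training examples.
example = scaled_kl([(1.0, 1.0), (1.0, 1.0)], n_train=4)
```

With more training data the KL penalty per example shrinks, so the likelihood term dominates the objective.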

I can then train this network. However, I would like the variance of the noise eps to follow a gamma distribution, i.e.

eps ~ N(0, sigma)

sigma ~ Gamma(a1, b1)
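
Written out, the model and a variational objective for it could look roughly as follows. This rests on two assumptions that are not stated in the question: that sigma is the standard deviation of eps (the text does not say whether it is a variance or a scale), and that the approximate posterior factorizes as q(w) q(sigma).

```latex
\begin{aligned}
w &\sim \mathcal{N}(0, I), \qquad \sigma \sim \mathrm{Gamma}(a_1, b_1), \\
y_t \mid x_t, w, \sigma &\sim \mathcal{N}\big(f_w(x_t), \sigma^2\big), \\
\mathrm{ELBO} &= \mathbb{E}_{q(w)\,q(\sigma)}\Big[\sum_t \log \mathcal{N}\big(y_t \mid f_w(x_t), \sigma^2\big)\Big]
  - \mathrm{KL}\big(q(w)\,\|\,p(w)\big) - \mathrm{KL}\big(q(\sigma)\,\|\,p(\sigma)\big).
\end{aligned}
```

Relative to the fixed-noise model, the changes are the expectation over q(sigma) and the additional KL term for sigma.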

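How can I achieve this in the TensorFlow Probability framework? I think I need to add another neuron to the last DenseFlipout layer and change the prior and posterior functions to ones that sample from a product of normal and gamma distributions, but I am not sure exactly how to implement this.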
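
One possible starting point, simpler than the full variational treatment asked about, is a MAP-style objective: keep sigma as a single trainable positive parameter and add the negative log-density of the Gamma(a1, b1) prior to the NLL. The framework-free sketch below makes several assumptions: sigma is treated as the standard deviation of eps, a1 is the shape and b1 the rate, and the prior term is divided by the number of training examples in the same way as the KL terms. All function names are hypothetical. In TensorFlow Probability the same two densities are available through `tfd.Normal(loc, scale).log_prob` and `tfd.Gamma(concentration, rate).log_prob`.

```python
import math

def normal_log_prob(y, mean, sigma):
    """Log density of N(mean, sigma^2) at y; sigma is a standard deviation."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
            - 0.5 * ((y - mean) / sigma) ** 2)

def gamma_log_prob(sigma, a1, b1):
    """Log density of Gamma(shape=a1, rate=b1) at sigma > 0."""
    return (a1 * math.log(b1) - math.lgamma(a1)
            + (a1 - 1.0) * math.log(sigma) - b1 * sigma)

def softplus(raw):
    """Map an unconstrained value to a positive one."""
    return math.log1p(math.exp(raw))

def noise_objective(ys, means, raw_sigma, a1, b1, n_train):
    """Per-example NLL plus the scaled negative log-prior on sigma.

    ys, means: observed responses and network outputs f(x_t) for one batch.
    raw_sigma: unconstrained trainable value; sigma = softplus(raw_sigma).
    n_train: total number of training examples (assumed scaling).
    """
    sigma = softplus(raw_sigma)
    nll = -sum(normal_log_prob(y, m, sigma) for y, m in zip(ys, means)) / len(ys)
    neg_log_prior = -gamma_log_prob(sigma, a1, b1) / n_train
    return nll + neg_log_prior

# Evaluate the objective for a toy batch and a hypothetical prior Gamma(2, 2).
loss = noise_objective([0.1, -0.3, 0.2], [0.0, 0.0, 0.0],
                       raw_sigma=0.0, a1=2.0, b1=2.0, n_train=3)
```

This gives a point estimate of sigma rather than a posterior over it. A posterior would need a variational distribution over sigma and its own KL term against the gamma prior.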

0 answers:

No answers yet.