Tensorflow概率。 -保存和恢复贝叶斯神经网络的检查点

时间:2019-02-12 00:43:00

标签: tensorflow probability bayesian checkpoint

我一直在研究Tensorflow概率库,并尝试修改bayesian network example中的示例,希望我可以保存检查点然后还原它们。我首先开始尝试使用tf.train.Checkpoint,但是,尽管保存或还原时都没有出现任何错误,但由于准确性完全不同,所以似乎没有从以前的检查点重新开始训练。 然后,我尝试使用tf.keras.models.model.save再次保存文件,但是在尝试还原时,出现错误:ValueError:未知图层:在尝试反序列化图层时使用了Conv2DFlipout。 老实说,如果有人能指出我正确的方向,我不知道该走哪条路。 谢谢! 乔凡娜

这是我到目前为止要还原的内容:

 if FLAGS.architecture == "resnet":
    model_fn = bayesian_resnet.bayesian_resnet
else:
    model_fn = bayesian_vgg.bayesian_vgg

model = model_fn(
  IMAGE_SHAPE,
  num_classes=4,
  kernel_posterior_scale_mean=FLAGS.kernel_posterior_scale_mean,
  kernel_posterior_scale_constraint=FLAGS.kernel_posterior_scale_constraint)
print(images)

#check if saved checkpoint exists
exists = os.path.isfile(FLAGS.model_dir+"checkpoint.hdf5")
if exists:
  model = tf.keras.models.load_model(FLAGS.model_dir+"checkpoint.hdf5") 

logits = model(images)
labels_distribution = tfd.Categorical(logits=logits)

# Perform KL annealing. The optimal number of annealing steps
# depends on the dataset and architecture.
t = tf.Variable(0.0)
#kl_regularizer = t / (FLAGS.kl_annealing * len(x_train) / FLAGS.batch_size)
...

0 个答案:

没有答案