在TF2中,不使用tf.keras API时如何保存模型/权重?

时间:2019-05-28 14:48:29

标签: tensorflow tensorflow2.0

在文档中,它们似乎专注于如何保存和恢复tf.keras.models,但是我想知道如何通过一些基本的迭代循环来保存和恢复经过定制训练的模型?

现在没有图或会话,我们如何保存在不使用层抽象的情况下自定义构建的tf函数中定义的结构?

1 个答案:

答案 0 :(得分:1)

您可以使用Tensorflow 1.x中的类似方式进行操作-通过使用检查点对象以及Tensorflow 2.0中引入的新闻,即检查点管理器。

ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
  print("Restored from {}".format(manager.latest_checkpoint))
else:
  print("Initializing from scratch.")

for example in toy_dataset():
  loss = train_step(net, example, opt)

您可以看一下Training checkpoints guide