如何在张量流的急切执行模式下保存和恢复模型?

时间:2019-12-10 18:45:24

标签: tensorflow eager-execution

考虑以tensorflow here实现的Transformer。您将如何保存,然后再在另一个python脚本中还原整个模型,以使用看不见的数据测试模型。以下方法是保存和恢复经过训练的模型的正确方法吗?

1)为了保存模型,我首先训练模型(与原始代码非常相似),然后使用以下代码保存Transformer和优化器:

checkpoint_path = "./checkpoints/trained_model"
ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
ckpt_save_path = ckpt_manager.save()

在另一个python脚本(test.py)中,我首先实现了与培训文件相同的所有类。我只删除了实际训练模型的部分(train_step函数和调用该函数的整个循环)。然后,我创建一个Transformer类的对象,并使用以下几行来还原模型:

checkpoint_path =  "./checkpoints/trained_model"
ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)

谢谢!

0 个答案:

没有答案