考虑以tensorflow here实现的Transformer。您将如何保存,然后再在另一个python脚本中还原整个模型,以使用看不见的数据测试模型。以下方法是保存和恢复经过训练的模型的正确方法吗?
1)为了保存模型,我首先训练模型(与原始代码非常相似),然后使用以下代码保存Transformer和优化器:
checkpoint_path = "./checkpoints/trained_model"
ckpt = tf.train.Checkpoint(transformer=transformer,
optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
ckpt_save_path = ckpt_manager.save()
在另一个python脚本(test.py)中,我首先实现了与培训文件相同的所有类。我只删除了实际训练模型的部分(train_step函数和调用该函数的整个循环)。然后,我创建一个Transformer类的对象,并使用以下几行来还原模型:
checkpoint_path = "./checkpoints/trained_model"
ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
谢谢!