Tensorboard恢复训练图

时间:2020-09-10 22:24:09

标签: pytorch tensorboard tensorboardx

我运行了一个强化学习培训脚本,该脚本使用Pytorch并将记录的数据记录到tensorboardX并保存检查点。现在我想继续训练。我如何告诉tensorboardX从我停下来的地方继续呢?谢谢!

2 个答案:

答案 0 :(得分:2)

在pytorch中,张量板与保存实际模型无关。 关于保存和继续培训,请查看saving and loading models的文档。

就是这样

保存

torch.save(the_model.state_dict(), PATH)

加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

答案 1 :(得分:1)

我弄清楚了如何继续训练。创建摘要编写器时,我们需要提供与初次训练时使用的相同的log_dir

from tensorboardX import SummaryWriter
writer = SummaryWriter('log_dir')

然后,在训练循环步骤中,需要从其离开的位置开始(而不是从0开始):

writer.add_scalar('average reward',rewards.mean(),step)