如何在Pytorch中保存和加载随机数生成器状态?

时间:2019-03-11 08:17:09

标签: pytorch random-seed reproducible-research

我正在Pytorch中训练DL模型,并想以确定性的方式训练我的模型。 如this官方指南中所述,我像这样设置随机种子:

np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

现在,我的训练很长,我想保存,然后再加载所有内容,包括RNG。我将torch.savetorch.load_state_dict用于模型和优化程序。

如何保存和加载随机数生成器?

1 个答案:

答案 0 :(得分:1)

您可以使用torch.get_rng_statetorch.set_rng_state

调用torch.get_rng_state时,您将以torch.ByteTensor的形式获取随机数生成器状态。

然后可以将此张量保存在文件中的某个位置,以后可以加载并使用torch.set_rng_state来设置随机数生成器状态。


使用numpy时,您当然可以使用以下命令进行相同操作:
numpy.random.get_statenumpy.random.set_state