固定火炬的种子random_split()

时间:2019-04-23 22:36:36

标签: pytorch torch

拆分数据集时是否可以为torch.utils.data.random_split()固定种子,以便可以重现测试结果?

2 个答案:

答案 0 :(得分:6)

正如您从 documentation 中看到的,可以将生成器传递给 random_split

random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

答案 1 :(得分:1)

您可以使用torch.manual_seed函数在全局范围内播种脚本:

import torch
torch.manual_seed(0)

有关更多信息,请参见reproducibility documentation

如果要专门为种子torch.utils.data.random_split设置种子,则可以在之后将其“重置”为初始值。只需像这样使用torch.initial_seed()

torch.manual_seed(torch.initial_seed())

AFAIK pytorch 提供诸如seedrandom_state之类的参数(例如,可以在sklearn中看到)。