拆分数据集时是否可以为torch.utils.data.random_split()
固定种子,以便可以重现测试结果?
答案 0 :(得分:6)
正如您从 documentation 中看到的,可以将生成器传递给 random_split
random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
答案 1 :(得分:1)
您可以使用torch.manual_seed
函数在全局范围内播种脚本:
import torch
torch.manual_seed(0)
有关更多信息,请参见reproducibility documentation。
如果要专门为种子torch.utils.data.random_split
设置种子,则可以在之后将其“重置”为初始值。只需像这样使用torch.initial_seed()
:
torch.manual_seed(torch.initial_seed())
AFAIK pytorch
不提供诸如seed
或random_state
之类的参数(例如,可以在sklearn
中看到)。