PyTorch Lightning将张量移动到validation_epoch_end中的正确设备

时间:2020-07-08 17:19:00

标签: python pytorch pytorch-lightning

我想用validation_epoch_end的{​​{1}}方法创建一个新的张量。官方docs(第48页)指出,我们应该避免直接进行LightningModule.cuda()的呼叫:

没有.cuda()或.to()调用。 。 。闪电帮你做这些。

我们鼓励使用.to(device)方法转移到正确的设备上。

type_as

但是,在步骤new_x = new_x.type_as(x.type())中,我没有任何张量以干净的方式从设备复制张量(通过validation_epoch_end方法)。

我的问题是,如果我想用这种方法创建一个新的张量并将其转移到模型所在的设备上,该怎么办?

我唯一能想到的就是在type_as字典中找到一个张量,但是感觉有点混乱:

outputs

有什么干净的方法可以实现这一目标吗?

1 个答案:

答案 0 :(得分:3)

您是否检查了链接文档中的3.4部分(第34页)?

LightningModules知道它们在什么设备上!直接在设备上构造张量,以避免CPU->设备传输

t = tensor.rand(2, 2).cuda()# bad
(self is lightningModule)t = tensor.rand(2,2, device=self.device)# good 

我有一个类似的问题来创建张量,这对我有帮助。希望对您有帮助。