将PyTorch设备名称传递给模型的最佳实践

时间:2020-04-29 14:05:25

标签: pytorch

目前,我在深度学习项目中将train.pymodel.py分开了。

因此对于数据集,它们被发送到 epoch for loop 内部的cuda设备,如下所示。

train.py

...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)

但是,在我的模型内部(在model.py中),它也需要访问设备变量来进行类似方法的跳过连接。制作隐藏单元的新副本(用于剩余连接)

model.py

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device)  # <--
        (some opps for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)

在这种情况下,我目前正在传递device作为参数,但是,我认为这不是最佳实践。

  1. 将数据集发送到CUDA的最佳实践应该在哪里?
  2. 对于需要访问device的多个脚本,我该如何处理? (参数,全局变量?)

1 个答案:

答案 0 :(得分:1)

您可以向MyModel添加新属性以存储device信息,并在skip_conn初始化中使用它。

class MyNet(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats, device): # <--
    super(MyNet, self).__init__()
    self.conv1 = GCNConv(in_feats, hid_feats)
    self.device = device # <--
    self.to(self.device) # <--
    ...

def forward(self, data):
    x, edge_index = data.x, data.edge_index
    x1 = copy.copy(x.float())
    x = self.conv1(x, edge_index)
    skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)  # <--
    (some opps for x1 -> skip_conn)
    x = torch.cat((x, skip_conn), 1)

请注意,在此示例中,MyNet负责所有设备逻辑,包括.to(device)调用。这样,我们将所有与模型相关的设备管理封装在模型类本身中。