我有4个图形卡,我想利用这些图形卡进行火炬传递。 我有这个网:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
如何在网上使用它们?
答案 0 :(得分:1)
您可以使用torch.nn.DataParallel在许多工作人员中分发模型。
只需将您的网络(torch.nn.Module
)传递给它的构造函数,然后像往常一样使用转发。您还可以通过为device_ids
提供List[int]
或torch.device
来指定应该在哪些GPU上运行。
仅出于代码考虑:
import torch
# Your network
network = Net()
torch.nn.DataParallel(network)
# Use however you wish
network.forward(data)