I am testing PyTorch on a machine with multiple GPUs. I am not using DataParallel; instead I want to use model parallelism. My model has two input branches, each placed on a separate GPU. I put together a reproducible example on MNIST, but during training I get the exception shown below. Any help would be appreciated.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Hyperparameters
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001
DATA_PATH = '/data/'
# transforms to apply to the data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# MNIST dataset
train_dataset = datasets.MNIST(root=DATA_PATH, train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(root=DATA_PATH, train=False, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
gpu1 = torch.device("cuda:0")
gpu2 = torch.device("cuda:1")
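(A note for anyone reproducing this: the example assumes two visible CUDA devices, cuda:0 and cuda:1, so a guard like the one below may be useful; it is not part of my original script.)

if torch.cuda.device_count() < 2:
    raise RuntimeError('this example needs at least two GPUs (cuda:0 and cuda:1)')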
class DistConvNet(nn.Module):
    def __init__(self):
        super(DistConvNet, self).__init__()
        # branch 1 on gpu1
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer1.to(gpu1)
        self.layer2.to(gpu1)
        # branch 2 on gpu2
        self.layer3 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer4 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3.to(gpu2)
        self.layer4.to(gpu2)
        # classifier head on gpu1; input is the concatenation of both branches
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64 * 2, 1000)
        self.fc2 = nn.Linear(1000, 10)
        self.drop_out.to(gpu1)
        self.fc1.to(gpu1)
        self.fc2.to(gpu1)

    def forward(self, x1, x2):
        out1 = self.layer1(x1)
        out1 = self.layer2(out1)
        out2 = self.layer3(x2)
        out2 = self.layer4(out2)
        out1 = out1.reshape(out1.size(0), -1)
        out2 = out2.reshape(out2.size(0), -1)
        out2.to(gpu1)
        out = torch.cat((out1, out2), 1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out
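The forward pass itself completes on my setup (the failure below only appears at loss.backward()), which I verified with a dummy batch like this (model_check and dummy are throwaway names, just for illustration):

model_check = DistConvNet()
dummy = torch.randn(4, 1, 28, 28)  # dummy MNIST-shaped batch
out = model_check(dummy.to(gpu1), dummy.to(gpu2))
print(out.shape, out.device)  # expect torch.Size([4, 10]) on cuda:0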
model_dist = DistConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_dist.parameters(), lr=learning_rate)
total_step = len(train_loader)
loss_list = []
acc_list = []
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Run the forward pass (one copy of the batch per branch/GPU)
        images_gpu1, images_gpu2, labels = images.to(gpu1), images.to(gpu2), labels.to(gpu1)
        outputs = model_dist(images_gpu1, images_gpu2)
        loss = criterion(outputs, labels)
        loss_list.append(loss.item())
        # Backprop and perform Adam optimisation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Track the accuracy
        total = labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        acc_list.append(correct / total)
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item(),
                          (correct / total) * 100))
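For completeness, here is the evaluation pass that goes with the otherwise-unused test_loader above; it is never reached, since training fails first:

model_dist.eval()
with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        outputs = model_dist(images.to(gpu1), images.to(gpu2))
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(gpu1)).sum().item()
print('Test accuracy: {:.2f}%'.format(100 * correct / total))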
Training then fails with the following exception:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-7-27984b0eb824> in <module>
13 # Backprop and perform Adam optimisation
14 optimizer.zero_grad()
---> 15 loss.backward()
16 optimizer.step()
17
/opt/conda/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
100 products. Defaults to ``False``.
101 """
--> 102 torch.autograd.backward(self, gradient, retain_graph, create_graph)
103
104 def register_hook(self, hook):
/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
88 Variable._execution_engine.run_backward(
89 tensors, grad_tensors, retain_graph, create_graph,
---> 90 allow_unreachable=True) # allow_unreachable flag
91
92
RuntimeError: Function CatBackward returned an invalid gradient at index 1 - expected device 1 but got 0
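To rule out misplaced parameters, a quick audit like the one below can confirm that each branch lives on the device assigned in __init__ (nn.Module.to() moves a module's parameters in place), so the device mismatch seems to surface only in the backward pass of torch.cat:

for name, param in model_dist.named_parameters():
    print(name, param.device)  # layer1/layer2/fc* on cuda:0, layer3/layer4 on cuda:1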
Thanks in advance for any help.