How do I run a pretrained PyTorch model on the GPU?

Time: 2020-06-19 18:13:02

Tags: pytorch

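I'm trying to train MobileNetV2 on a custom dataset. It runs fine on the CPU, but I would prefer to run it on the GPU. When I try, I get the following errors:

RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'weight'

RuntimeError: Expected object of backend CPU but got backend CUDA for argument #4 'mat1'

As the title asks, how can I get a pretrained model to run on the GPU?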

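import time

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

MobileNet = models.mobilenet_v2(pretrained = True)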
if torch.cuda.is_available():
    MobileNet.cuda()

for param in MobileNet.parameters():
    param.requires_grad = False

torch.manual_seed(50)

MobileNet.classifier = nn.Sequential(nn.Linear(1280, 1000), nn.ReLU(), nn.Dropout(0.5), nn.Linear(1000,3), nn.LogSoftmax(dim=1))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(MobileNet.classifier.parameters(), lr=0.001)

train_transform = transforms.Compose([
        transforms.RandomRotation(10),      # rotate +/- 10 degrees
        transforms.RandomHorizontalFlip(),  # reverse 50% of images
        transforms.Resize(224),             # resize shortest side to 224 pixels
        transforms.CenterCrop(224),         # crop longest side to 224 pixels at center
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

train_data = datasets.ImageFolder('C:/Users/mixv/Pictures/Summer/datasets/train', transform=train_transform)
test_data = datasets.ImageFolder('C:/Users/mix/Pictures/Summer/datasets/test', transform=test_transform)


torch.manual_seed(42)
batch=64
# Pin host memory only when a GPU is available
pin = torch.cuda.is_available()
train_loader = DataLoader(train_data, batch_size=batch, shuffle=True, pin_memory=pin)
test_loader = DataLoader(test_data, batch_size=batch, shuffle=True, pin_memory=pin)

epochs = 10

train_losses = []
test_losses = []
train_correct = []
test_correct = []
start_time = time.time()
for i in range(epochs):
    trn_corr = 0
    tst_corr = 0

    # Run the training batches

    for b, (images, labels) in enumerate(train_loader):

        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()

        b+=1

        # Apply the model
        y_pred = MobileNet(images)
        loss = criterion(y_pred, labels)

        # Tally the number of correct predictions
        predicted = torch.max(y_pred.data, 1)[1]
        batch_corr = (predicted == labels).sum()
        trn_corr += batch_corr

        accuracy = trn_corr.item()*100/(b*batch)
        # Update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

1 Answer:

Answer 0 (score: 0)

As the RuntimeError says, some weights are still on the CPU. I suspect the problem is that MobileNet.cuda() is called before MobileNet.classifier = nn.Sequential(...) is assigned, which means the newly created classifier weights are probably never sent to the GPU. Try swapping the order of the two and see whether that helps.
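
The reordering suggested above can be sketched roughly as follows. This is a minimal illustration, not the asker's exact code: the helper name prepare_for_finetuning and the TinyBackbone class are hypothetical, and TinyBackbone only stands in for models.mobilenet_v2(pretrained=True) so the snippet runs without downloading weights. The head here outputs raw logits, on the assumption that it is paired with nn.CrossEntropyLoss, which already applies log-softmax internally.

```python
import torch
from torch import nn


def prepare_for_finetuning(model: nn.Module, in_features: int,
                           num_classes: int, device: torch.device) -> nn.Module:
    """Freeze the backbone, attach a new head, then move everything to `device`."""
    # Freeze the existing (pretrained) parameters.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the head first. Freshly constructed layers live on the CPU
    # and are trainable (requires_grad=True) by default.
    model.classifier = nn.Sequential(
        nn.Linear(in_features, 1000),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(1000, num_classes),
    )

    # Move the model last, so the new head is transferred as well.
    return model.to(device)


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a torchvision model with a `.classifier` head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(8, 1000)

    def forward(self, x):
        return self.classifier(self.features(x))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = prepare_for_finetuning(TinyBackbone(), in_features=8,
                               num_classes=3, device=device)

# Quick diagnostic: collect the device type of every parameter.
param_devices = {p.device.type for p in model.parameters()}

# Inputs must be on the same device as the model.
images = torch.randn(4, 3, 224, 224, device=device)
logits = model(images)
```

With the real model, the same call would presumably look like prepare_for_finetuning(models.mobilenet_v2(pretrained=True), 1280, 3, device), where 1280 is the classifier input size used in the question. Checking param_devices before training is a cheap way to confirm that no layer was left behind on the CPU.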