Model not saved after training in PyTorch

Date: 2021-03-29 07:40:36

Tags: pytorch, save, classification, torch

I have the following problem. I perform incremental cross-validation: my dataset contains 20 subjects, and I try to classify images. I start with 3 subjects and do cross-validation with k=3; that is, I train 3 different models and validate on the left-out subject. I do the same for 4, 5, ..., 20 drivers. As a result, I train a lot of models.

Now I want to check the performance of all these models on another dataset, but for some reason every model yields exactly the same accuracy, so there must be a bug somewhere.

I already use copy.deepcopy(), so the bug must be somewhere else. I'm open to any hints!
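For reference, the reason copy.deepcopy() matters here is that state_dict() returns references to the live parameter tensors, not copies. A minimal sketch, with a toy nn.Linear standing in for the real model:

```python
import copy

import torch
import torch.nn as nn

# state_dict() returns references to the live parameter tensors, so a
# snapshot taken WITHOUT deepcopy keeps changing as training continues.
torch.manual_seed(0)
model = nn.Linear(4, 2)

shallow = model.state_dict()              # references the live tensors
deep = copy.deepcopy(model.state_dict())  # an independent copy

# Simulate one optimizer step by mutating the weights in-place.
with torch.no_grad():
    model.weight.add_(1.0)

print(torch.equal(shallow["weight"], model.weight))  # True: tracked the change
print(torch.equal(deep["weight"], model.weight))     # False: snapshot unchanged
```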

Here is the code of the training function:

import copy
import time

import torch


def train_model(model, num_classes, dirname, trainloader, valloader, trainset_size, valset_size, criterion, optimizer, scheduler, patience, min_delta, num_epochs, fold):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    since = time.time()

    train_loss, train_acc, val_loss, val_acc = [], [], [], []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    early_stopping = False
    counter = 0
    last_train_epoch = 0

    for epoch in range(num_epochs):

        if early_stopping:
            print('\nEarly Stopping')
            break

        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                dataloader = trainloader
                dataset_size = trainset_size
            else:
                model.eval()   # Set model to evaluate mode
                dataloader = valloader
                dataset_size = valset_size

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloader:
                model = model.to(device)
                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    # zero the parameter gradients
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects.double() / dataset_size

            if phase == 'train':
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # early stopping
            if phase == 'val':
                if counter == patience:
                    early_stopping = True
                    break

                if epoch == 0:
                    best_loss = epoch_loss
                else:
                    if best_loss >= epoch_loss + min_delta:
                        print('Validation loss decreased ({:.4f} --> {:.4f}).  Saving model ...'.format(best_loss, epoch_loss))
                        best_model_wts = copy.deepcopy(model.state_dict())
                        torch.save(model.state_dict(), '{}/weights/model_fold_{}.pth'.format(dirname, fold))
                        last_train_epoch = epoch + 1
                        best_loss = epoch_loss
                        counter = 0
                    else:
                        counter += 1

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, train_acc, train_loss, val_acc, val_loss, last_train_epoch
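As a side note, the patience/min_delta bookkeeping in train_model can be summarized in isolation. This is only a sketch of the same counter logic, driven by a synthetic loss history instead of real validation losses:

```python
# Sketch of the patience-based early-stopping logic used in train_model above:
# the counter resets on any improvement of at least min_delta and stopping is
# triggered once it reaches `patience` epochs without improvement.
def should_stop(loss_history, patience, min_delta):
    best = float("inf")
    counter = 0
    for loss in loss_history:
        if best >= loss + min_delta:  # improved by at least min_delta
            best = loss
            counter = 0
        else:
            counter += 1
        if counter >= patience:
            return True
    return False


print(should_stop([1.0, 0.9, 0.91, 0.92, 0.93], patience=2, min_delta=0.0))  # True
print(should_stop([1.0, 0.9, 0.8, 0.7], patience=2, min_delta=0.0))          # False
```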

Here is how I call the function:

    model = net
    model = model.to(device)   

    # train 
    [model, train_acc, train_loss, val_acc, val_loss, last_train_epoch] = train_model(model, num_classes, dirname,\
                                                                                     trainloader, valloader,\
                                                                                     trainset_size, valset_size,\
                                                                                     criterion, optimizer, \
                                                                                     exp_lr_scheduler, patience, \
                                                                                     min_delta, num_epochs, fold=val_index)
    
    # test model 
    [preds_val, labels_val, idx_false_val, pred_time_val_fold] = test(model, valloader)
    [preds_tr, labels_tr, idx_false_train, pred_time_train_fold] = test(model, trainloader)
    [preds_all, labels_all, idx_false_all, pred_time_all_fold] = test(model, allloader)
    print('Accuracy on all data: ', accuracy_score(labels_all, preds_all))
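Since the goal is to compare many fold models on the same data, one option is a small helper that loads each saved checkpoint into the same architecture and reports its accuracy. This is only a hypothetical sketch (the model, loader, and checkpoint paths below are placeholders, not the original code):

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_checkpoints(model, checkpoint_paths, dataloader, device="cpu"):
    """Load each checkpoint into `model` and return its accuracy on `dataloader`."""
    accuracies = {}
    for path in checkpoint_paths:
        model.load_state_dict(torch.load(path, map_location=device))
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                preds = model(inputs.to(device)).argmax(dim=1)
                correct += (preds == labels.to(device)).sum().item()
                total += labels.size(0)
        accuracies[path] = correct / total
    return accuracies


# Demo with a toy model, two randomly initialized "fold" checkpoints,
# and synthetic data in place of the real dataset.
tmp = tempfile.mkdtemp()
paths = []
for fold in range(2):
    torch.manual_seed(fold)
    path = os.path.join(tmp, "model_fold_{}.pth".format(fold))
    torch.save(nn.Linear(8, 3).state_dict(), path)
    paths.append(path)

torch.manual_seed(0)
data = TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,)))
loader = DataLoader(data, batch_size=8)
print(evaluate_checkpoints(nn.Linear(8, 3), paths, loader))
```

If every checkpoint produces the same accuracy here as well, the problem is in what was saved; if they differ, the problem is in how the models were loaded for testing.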

For completeness, this is what the test() function looks like:

def test(model, dataloader):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    pred_labels, gt_labels, idx_false, pred_time = [], [], [], []

    was_training = model.training
    model.eval()

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # time.clock() was removed in Python 3.8; perf_counter() replaces it
            start_pred = time.perf_counter()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            end_pred = time.perf_counter()

            pred_time.append(end_pred - start_pred)

            for j in range(inputs.size(0)):
                pred_labels.append(preds[j].item())
                gt_labels.append(labels[j].item())

    # collect indices of misclassified samples
    for i in range(len(pred_labels)):
        if pred_labels[i] != gt_labels[i]:
            idx_false.append(i)

    model.train(mode=was_training)

    return pred_labels, gt_labels, idx_false, pred_time

Edit: It looks as if the same model is saved every time, even though I try to make sure that only the updated weights of the best model are saved.
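One way to verify that suspicion is to compare the saved fold checkpoints tensor by tensor. A minimal sketch (the file names are placeholders; the demo uses two freshly initialized layers instead of real checkpoints):

```python
import os
import tempfile

import torch
import torch.nn as nn


def checkpoints_identical(path_a, path_b):
    """Return True if the two saved state dicts contain exactly the same tensors."""
    sd_a = torch.load(path_a, map_location="cpu")
    sd_b = torch.load(path_b, map_location="cpu")
    return sd_a.keys() == sd_b.keys() and all(
        torch.equal(sd_a[k], sd_b[k]) for k in sd_a
    )


# Demo: two different initializations should NOT be identical,
# while a file compared with itself trivially is.
tmp = tempfile.mkdtemp()
torch.manual_seed(0)
path_0 = os.path.join(tmp, "model_fold_0.pth")
torch.save(nn.Linear(4, 2).state_dict(), path_0)
torch.manual_seed(1)
path_1 = os.path.join(tmp, "model_fold_1.pth")
torch.save(nn.Linear(4, 2).state_dict(), path_1)

print(checkpoints_identical(path_0, path_1))  # False
print(checkpoints_identical(path_0, path_0))  # True
```

If this reports True for different folds, the bug is on the saving side; otherwise it is on the loading/evaluation side.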

0 Answers