如何加载经过训练的模型以推断预测数据

时间:2021-01-22 01:33:01

标签: python validation pytorch conv-neural-network prediction

我训练并保存了 1000 个 epoch 的 CNN 模型,现在想要检索验证数据(预测图像)。在下面的代码中,test_predtest_real 在验证集中输出预测图像和真实图像。我应该加载并运行保存的模型再运行 1 个 epoch 以检索预测图像(这将在 CUDA 内存不足时结束,因为数据很大)?或者还有其他方法吗?您可以在下面看到我的部分代码:

for epoch in range(epochs):
    mse_train_losses= []
    mae_train_losses = []
    N_train = []
    mse_val_losses = []
    mae_val_losses = []
    N_test = []
    
    if save_model:
        if epoch % 50 ==0:
            checkpoint = {'state_dict' : model.state_dict(),'optimizer' : optimizer.state_dict()}
            save_checkpoint(checkpoint)
   
    model.train()
    for data in train_loader:

        x_train_batch, y_train_batch = data[0].to(device, 
            dtype=torch.float), data[1].to(device, dtype=torch.float)  
        y_train_pred = model(x_train_batch)            # 1) Forward pass
        mse_train_loss = criterion(y_train_batch, y_train_pred, x_train_batch, mse) 
        mae_train_loss = criterion(y_train_batch, y_train_pred, x_train_batch, l1loss)  
        
        optimizer.zero_grad()                   
        mse_train_loss.backward()                        
        optimizer.step()                        
        
        mse_train_losses.append(mse_train_loss.item())
        mae_train_losses.append(mae_train_loss.item())
        N_train.append(len(x_train_batch))
        
                       
    test_pred=[] 
    test_real=[]
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            x_test_batch, y_test_batch = data[0].to(device, 
                dtype=torch.float), data[1].to(device, dtype=torch.float)

            y_test_pred = model(x_test_batch)
            mse_val_loss = criterion(y_test_batch, y_test_pred, x_test_batch, mse)
            mae_val_loss = criterion(y_test_batch, y_test_pred, x_test_batch, l1loss)
            
            mse_val_losses.append(mse_val_loss.item())
            mae_val_losses.append(mae_val_loss.item())
            N_test.append(len(x_test_batch))
            
            test_pred.append(y_test_pred)                        
            test_real.append(y_test_batch)

1 个答案:

答案 0 :(得分:0)

当你将它添加到列表中时,尝试像这样在末尾使用 .cpu():

test_pred.append(t_test_pred.cpu())