联合训练两个网络时 pytorch 的 Cuda 内存不足问题

时间:2021-06-02 07:23:04

标签: pytorch

我尝试联合训练两个 DNN,模型经过训练并在每 5 个 epochs 后进入验证阶段,问题是在 5 个 epochs 之后还可以并且内存没有问题,但是在 10 个 epochs 之后模型抱怨库达内存。任何帮助解决内存问题。

class Trainer(BaseTrainer):
    def __init__(self, config, resume: bool, model, loss_function, optimizer, train_dataloader, validation_dataloader):
        super(Trainer, self).__init__(config, resume, model, loss_function, optimizer)
        self.train_data_loader = train_dataloader
        self.validation_data_loader = validation_dataloader
        self.model = self.model.double()
        

    def _train_epoch(self, epoch):
        #loss_total = 0.0
        for i, (mixture, clean, name, label) in enumerate(self.train_data_loader):
            mixture = mixture.to(self.device, dtype=torch.double)
            clean = clean.to(self.device, dtype=torch.double)
            enhanced = self.model(mixture).to(self.device)
            front_loss = self.loss_function(clean, enhanced)
            
            front_loss.backward(retain_graph=True)
            
            torch.cuda.empty_cache()
            model_back.train()
            y = model_back(enhanced.double().to(device2))
            back_loss = backend_loss(y[0], label[0].to(device2))
            print("Iteration %d in epoch%d--> loss = %f" % (i, epoch, back_loss.item()), end='\r')
            
            #optimizer_back.zero_grad()
            back_loss.backward(retain_graph=True)
            self.optimizer.zero_grad()
            self.optimizer.step()
            #optimizer_back.step()
            torch.cuda.empty_cache()
            #loss_total += float(front_loss.item() + back_loss.item())
        dl_len = len(self.train_data_loader)
        #self.writer.add_scalar(f"Train/Loss", loss_total / dl_len, epoch)

    @torch.no_grad()
    def _validation_epoch(self, epoch):
            #visualize_audio_limit = self.validation_custom_config["visualize_audio_limit"]
            #visualize_waveform_limit = self.validation_custom_config["visualize_waveform_limit"]
            #visualize_spectrogram_limit = self.validation_custom_config["visualize_spectrogram_limit"]

            sample_length = self.validation_custom_config["sample_length"]

            stoi_c_n = []  # clean and noisy
            stoi_c_e = []  # clean and enhanced
            stoi_e_n = []
            pesq_c_n = []
            pesq_c_e = []
            pesq_e_n = []
            correct = []

            for i, (mixture, clean, name, label) in enumerate(self.validation_data_loader):
                #assert len(name) == 1, "Only support batch size is 1 in enhancement stage."
                name = name[0]
                padded_length = 0

                mixture = mixture.to(self.device)

                if mixture.size(-1) % sample_length != 0:
                    padded_length = sample_length - (mixture.size(-1) % sample_length)
                    mixture = torch.cat([mixture, torch.zeros(1, 1, padded_length, device=self.device)], dim=-1)

                assert mixture.size(-1) % sample_length == 0 and mixture.dim() == 3
                mixture_chunks = list(torch.split(mixture, sample_length, dim=-1))

                enhanced_chunks = []
                for chunk in mixture_chunks:
                    enhanced_chunks.append(self.model(chunk.double()).detach().cpu())

                enhanced = torch.cat(enhanced_chunks, dim=-1)  # [1, 1, T]
                enhanced = enhanced.to(self.device)
                #print(enhanced)
                if padded_length != 0:
                    enhanced = enhanced[:, :, :-padded_length]
                    mixture = mixture[:, :, :-padded_length]

                torch.cuda.empty_cache()
                model_back.eval()
                
                y_pred = model_back(enhanced.double().to(self.device))
                pred = torch.argmax(y_pred[0].detach().cpu(), dim=1)
                intent_pred = pred
                correct.append((intent_pred == label[0]).float())
                torch.cuda.empty_cache()
            acc = np.mean(np.hstack(correct))
            intent_acc = acc
            iter_acc = '\n iteration %d epoch %d -->' %(i, epoch)
            print(iter_acc, acc, best_accuracy)
            if intent_acc > best_accuracy:
                improved_accuracy = 'Current accuracy {}, {}'.format(intent_acc, best_accuracy)
                print(improved_accuracy)
                torch.save(model_back.state_dict(), '/home/mnabih/jt/best_model.pkl')

我尝试过的一些解决方案是使用垃圾收集器工具(gc 库)并在训练和验证中随机添加 torch.Cuda.empty_cache()。

1 个答案:

答案 0 :(得分:0)

不要在第二次向后传球时使用 retain_graph = True。此标志使您的代码可以永久存储每个批次的计算图。

除了最后一次调用 retain_graph 之外的所有调用都应该使用 backward(),该调用通过相同的变量/参数进行反向传播。