训练深度神经网络时更新值

时间:2021-06-20 21:18:05

标签: python python-3.x

我正在使用以下训练函数训练模型:

   def _train_epoch(self, epoch):
       loss_total = 0.0
       front_total = 0.0
       back_total = 0.0
       corr = []
       best_acc = 0.0
       for i, (mixture, clean, name, label) in enumerate(self.train_data_loader):
           mixture = mixture.to(self.device, dtype=torch.float)
           clean = clean.to(self.device, dtype=torch.float)
           if i % 10 == 0:
               enhanced = self.model(mixture).to(self.device)
               front_loss = self.loss_function(clean, enhanced)
               model_back.train()
               y = model_back(enhanced.float().to(device2))
               back_loss = self.loss_function2(y[0], label[0].to(device2))
               loss = front_loss + back_loss
               print("Iteration %d in epoch%d--> front_loss = %f  back_loss = %f loss = %f " % (i, epoch, front_loss.item(), back_loss.item() ,loss.item()), end='\r')
               loss_total += loss.item()
               front_total += front_loss.item()
               back_total += back_loss.item()
               loss.backward()
               self.optimizer.step()
               optimizer_back.step()
               self.optimizer.zero_grad()
               optimizer_back.zero_grad()
               p = torch.argmax(y[0].detach().cpu(), dim=1)
               intent_p = p
               corr.append((intent_p == label[0]).float())
               #torch.cuda.empty_cache()
       ac = np.mean(np.hstack(corr))
       intent_ac = ac
       iter_ac = '\n iteration %d epoch %d -->' %(i, epoch)
       print(iter_ac, ac, best_acc)
       if intent_ac > best_acc:
           improved_acc = 'Current accuracy {}, {}'.format(intent_ac, best_acc)
           best_acc = intent_ac
           print(improved_acc)

我有一个基本的python问题,但很抱歉我无法解决它,问题是best_acc始终为0并且没有更新,而我正在设置条件以在intent_ac更好时更新其值如图所示,比 best_acc:并不是说 intent_ac 正在改进,因此应该更新 best_acc。 ============== 1757 纪元 ==============

epoch1757 迭代 5990--> front_loss = 0.008438 back_loss = 1.030782 loss = 1.039219 迭代 5999 纪元 1757 --> 0.8016667 0.0

电流精度 0.8016666769981384, 0.0

============== 1758 纪元 ==============

epoch1758 迭代 5990--> front_loss = 0.013248 back_loss = 1.306771 loss = 1.320019 迭代 5999 纪元 1758 --> 0.81 0.0

电流精度 0.8100000023841858, 0.0

============== 1759 纪元 ==============

epoch1759 迭代 5990--> front_loss = 0.008453 back_loss = 1.679812 loss = 1.688265 迭代 5999 纪元 1759 --> 0.81 0.0 电流精度 0.8100000023841858, 0.0

0 个答案:

没有答案