如何解决 PyTorch 自定义损失函数中的梯度爆炸/损失问题?

时间:2021-03-16 23:31:33

标签: optimization deep-learning pytorch autograd

我已经按照本文的说明编写了一个自定义损失函数:https://epubs.siam.org/doi/pdf/10.1137/1.9781611976236.18 但是,当我尝试使用此损失函数训练模型时,所有输出值在第一次损失后都变为 NaN。向后()操作。我的猜测是此代码中的某些内容导致渐变爆炸或丢失,但我不知道它在哪里/是什么。有人可以帮忙吗?一个最低限度的可重现示例如下:

import torch
from torch import optim
import torchvision
from torchvision import datasets
import timm # you may need to pip install this!!

### LOSS FUNCTION ###

def class_means(Z, Y):
    '''Returns unique classes in batch and per class means'''
    classes = torch.unique(Y)
    means = None
    for c in classes:
        if means == None:
            means = torch.mean(Z[Y == c], axis = 0)
        else:
            means = torch.vstack([means, torch.mean(Z[Y == c], axis = 0)])
    return classes, means


def intra_spread(Z, Y, classes, means):
    ''' Takes the L2 norm of all outputs (Z) from their respective class means
    and averages them'''
    N = Z.shape[0]
    intraSpread = 0
    for i in range(classes.shape[0]):
        intraSpread += torch.sqrt(torch.sum((means[i] - Z[Y == classes[i]]) ** 2))
    return intraSpread / N

def similarity_matrix(mat):
    '''Return the distances between all rows of the input matrix'''
    r = torch.mm(mat, mat.t())
    diag = r.diag().unsqueeze(0)
    diag = diag.expand_as(r)
    D = diag + diag.t() - 2*r
    return D.sqrt()

def inter_separation(means):
    '''Returns the distance between the two closest means in input means'''
    return torch.min(similarity_matrix(means)[similarity_matrix(means) > 0])

def ii_loss(Z, Y):
    '''Returns intraSpread - interSep'''
    classes, means = class_means(Z, Y)
    intraSpread = intra_spread(Z, Y, classes, means)
    interSep = inter_separation(means)
    return intraSpread - interSep

criterion = ii_loss # use ii_loss as the criterion

# Loading in CIFAR 10. Make a files folder in your working directory
batch_size_train = 32
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.CIFAR10('./files/', train=True, download=True,
                        transform=torchvision.transforms.Compose([
                           torchvision.transforms.Resize((300, 300)),
                           torchvision.transforms.ToTensor(),
                           torchvision.transforms.Normalize(
                             (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
  batch_size=batch_size_train, shuffle=True)

### Fitting Model ###
## Load model
device = 0 # or 0 for gpu
m = timm.create_model('efficientnet_b3', pretrained = True)
m.classifier = torch.nn.Linear(m.classifier.in_features, 10) # 10 classes in cifar10
m = m.to(device)

## Establish hyperparameters
learning_rate = 1e-4
momentum = 0.9
reg = 1e-4
epochs = 18
decay_epochs = 15
decay = 0.1
optimizer = optim.Adam(m.parameters(), lr = learning_rate)

## Train model
for i in range(epochs):
    print("Epoch: " + str(i + 1))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print("  " + str(batch_idx + 1) + "...", end = "")
        inputs = inputs.to(device) # X_batch
        targets = targets.to(device) # Y_batch
        optimizer.zero_grad()
        outputs = m(inputs) # Z_batch
        loss = criterion(outputs, targets) # loss_batch
        loss.backward()
        optimizer.step()
        print("DONE")

0 个答案:

没有答案