I wrote a custom loss function following this paper: https://epubs.siam.org/doi/pdf/10.1137/1.9781611976236.18 However, when I try to train a model with this loss function, all of the output values become NaN after the first loss.backward() call. My guess is that something in this code is causing exploding or vanishing gradients, but I can't figure out where or what it is. Can anyone help? A minimal reproducible example is below:
import torch
from torch import optim
import torchvision
from torchvision import datasets
import timm # you may need to pip install this!!
### LOSS FUNCTION ###
def class_means(Z, Y):
    '''Returns unique classes in batch and per class means'''
    classes = torch.unique(Y)
    means = None
    for c in classes:
        if means is None:
            means = torch.mean(Z[Y == c], axis = 0)
        else:
            means = torch.vstack([means, torch.mean(Z[Y == c], axis = 0)])
    return classes, means
def intra_spread(Z, Y, classes, means):
    ''' Takes the L2 norm of all outputs (Z) from their respective class means
    and averages them'''
    N = Z.shape[0]
    intraSpread = 0
    for i in range(classes.shape[0]):
        intraSpread += torch.sqrt(torch.sum((means[i] - Z[Y == classes[i]]) ** 2))
    return intraSpread / N
def similarity_matrix(mat):
    '''Return the distances between all rows of the input matrix'''
    r = torch.mm(mat, mat.t())
    diag = r.diag().unsqueeze(0)
    diag = diag.expand_as(r)
    D = diag + diag.t() - 2*r
    return D.sqrt()
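# (similarity_matrix uses the expansion ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2*<x_i, x_j>,
#  so its off-diagonal entries should match torch.cdist(mat, mat) up to floating-point error.)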
def inter_separation(means):
    '''Returns the distance between the two closest means in input means'''
    D = similarity_matrix(means)
    return torch.min(D[D > 0])
def ii_loss(Z, Y):
    '''Returns intraSpread - interSep'''
    classes, means = class_means(Z, Y)
    intraSpread = intra_spread(Z, Y, classes, means)
    interSep = inter_separation(means)
    return intraSpread - interSep
criterion = ii_loss # use ii_loss as the criterion
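# Quick standalone check of the loss, isolated from the model: random
# embeddings and labels (the shapes here are made up just for this check).
Z_check = torch.randn(32, 10, requires_grad = True)
Y_check = torch.randint(0, 10, (32,))
check_loss = criterion(Z_check, Y_check)
check_loss.backward()
print("standalone loss:", check_loss.item(),
      "| NaN in grad:", torch.isnan(Z_check.grad).any().item())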
# Loading in CIFAR 10. Make a files folder in your working directory
batch_size_train = 32
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./files/', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((300, 300)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
    batch_size=batch_size_train, shuffle=True)
### Fitting Model ###
## Load model
device = 0 # GPU device index 0; use 'cpu' instead if no GPU is available
m = timm.create_model('efficientnet_b3', pretrained = True)
m.classifier = torch.nn.Linear(m.classifier.in_features, 10) # 10 classes in cifar10
m = m.to(device)
## Establish hyperparameters
learning_rate = 1e-4
momentum = 0.9
reg = 1e-4
epochs = 18
decay_epochs = 15
decay = 0.1
optimizer = optim.Adam(m.parameters(), lr = learning_rate)
## Train model
for i in range(epochs):
print("Epoch: " + str(i + 1))
for batch_idx, (inputs, targets) in enumerate(train_loader):
print(" " + str(batch_idx + 1) + "...", end = "")
inputs = inputs.to(device) # X_batch
targets = targets.to(device) # Y_batch
optimizer.zero_grad()
outputs = m(inputs) # Z_batch
loss = criterion(outputs, targets) # loss_batch
loss.backward()
optimizer.step()
print("DONE")