I'm training a CNN on ImageNet and I'm seeing some strange periodic oscillations in the loss. The curve looks generally correct, but as it flattens out you can see clear oscillations. I log the loss every 100 mini-batches, and since my batch size is 128 and my training set is ~100k images, that works out to about 781 mini-batches, i.e. just under 8 logged points per epoch (quick arithmetic below). So the periodicity I'm seeing is the loss climbing over those steps, dropping sharply when a new epoch starts, and then gradually climbing again until the next epoch.
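For reference, here's the quick arithmetic behind that periodicity (assuming exactly 100,000 images; my actual count is only approximate):

# Back-of-the-envelope: logged points per epoch
images = 100_000        # approximate training-set size
batch_size = 128
log_every = 100         # mini-batches per logged point

batches_per_epoch = images // batch_size           # 781
points_per_epoch = batches_per_epoch / log_every   # 7.81
print(batches_per_epoch, points_per_epoch)         # 781 7.81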
Most of my training code is shown below (it isn't shown in full, for brevity). I reset running_loss to zero at the start of each epoch and again after every 100 mini-batches. Does anyone have any insight here? I don't know whether I'm doing something wrong in how I'm computing/collecting the loss, or whether something is more fundamentally wrong with my model.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# ImageDataset, AlexNetPyTorch, transform, train_file_list, NUM_CLASSES,
# device, and validation_data are defined elsewhere (omitted for brevity).
train_data = ImageDataset(train_file_list, transform=transform)
train_data_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=4)

alexnet = AlexNetPyTorch(NUM_CLASSES)
alexnet = alexnet.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=alexnet.parameters(), lr=3.0e-4, weight_decay=5.0e-4)

history = {}
history['loss'] = []
history['val_loss'] = []
history['accuracy'] = []
history['val_accuracy'] = []

step_increment = 100

for epoch in range(30):
    running_loss = 0.0
    running_val_loss = 0.0  # (not actually used below)
    for step, data in enumerate(train_data_loader):
        alexnet.train()
        X_train, y_train = data
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # forward + backward + optimize
        y_pred = alexnet(X_train)
        loss = loss_fn(y_pred, y_train)

        # zero the parameter gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate training loss over the logging window
        running_loss += loss.item()

        if step % step_increment == step_increment - 1:  # log every 100 mini-batches
            with torch.no_grad():
                alexnet.eval()
                # training accuracy on the most recent mini-batch only
                _, preds = torch.max(y_pred, 1)
                train_accuracy = torch.sum(preds == y_train).item() / len(y_train)

                # validation loss/accuracy on a fixed validation batch
                X_val, y_val = validation_data
                X_val = X_val.to(device)
                y_val = y_val.to(device)
                y_pred = alexnet(X_val)
                val_loss = loss_fn(y_pred, y_val).item()
                _, preds = torch.max(y_pred, 1)
                val_accuracy = torch.sum(preds == y_val).item() / len(y_val)

            history['loss'].append(running_loss / step_increment)
            history['accuracy'].append(train_accuracy)
            history['val_accuracy'].append(val_accuracy)
            history['val_loss'].append(val_loss)

            print('[%d, %5d] loss: %.3f acc: %.3f val loss: %.3f val acc: %.3f' %
                  (epoch + 1, step + 1, running_loss / step_increment,
                   train_accuracy, val_loss, val_accuracy))

            running_loss = 0.0
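One bookkeeping detail I noticed while writing this up: with ~781 mini-batches per epoch, the step % step_increment condition fires only 7 times per epoch, so the last ~81 batches accumulate into running_loss and are then thrown away when the next epoch resets it. I don't think that alone explains the oscillation (those batches are simply never logged), but for completeness this is the variant of the accumulation I was considering, which averages over the number of batches actually accumulated rather than assuming a full window of 100 (just a sketch I haven't run; window is an illustrative name):

# Sketch: also log the partial window at the end of each epoch,
# dividing by the number of batches actually accumulated.
for epoch in range(30):
    running_loss = 0.0
    window = 0  # mini-batches accumulated since the last log
    for step, (X_train, y_train) in enumerate(train_data_loader):
        alexnet.train()
        X_train, y_train = X_train.to(device), y_train.to(device)
        y_pred = alexnet(X_train)
        loss = loss_fn(y_pred, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        window += 1
        # log on a full window OR at the final batch of the epoch
        if window == step_increment or step == len(train_data_loader) - 1:
            history['loss'].append(running_loss / window)
            running_loss = 0.0
            window = 0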