I am trying to re-implement Elastic Weight Consolidation (EWC) as outlined in this paper. As a reference, I am also using this repo (another implementation).
My model/idea is very simple: first I train the network on bitwise AND (e.g. 1 && 0 = 0), and then, using EWC, I train it on OR (e.g. 1 || 0 = 1). There are three inputs: bit1, bit2, and the operation (0 for AND, 1 for OR), and a single output neuron that gives the result of the operation. For example, given the input 0 1 0, the ground truth should be 0.
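For concreteness, the two tasks are just the two full truth tables. A minimal sketch of how such a toy dataset could be built (the helper below is illustrative, not my actual data code):

    import torch

    # Each row is (bit1, bit2, op) with op = 0 for AND and op = 1 for OR.
    def make_task(op: int):
        rows, labels = [], []
        for b1 in (0, 1):
            for b2 in (0, 1):
                rows.append([b1, b2, op])
                labels.append(b1 & b2 if op == 0 else b1 | b2)
        return torch.tensor(rows), torch.tensor(labels, dtype=torch.float32)

    and_inputs, and_labels = make_task(0)  # old task: AND
    or_inputs, or_labels = make_task(1)    # new task: OR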
However, the problem arises when computing the EWC loss.
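For context, the penalty is meant to implement the quadratic regularizer from the EWC paper, where F_i is the diagonal Fisher information for parameter i and θ*_i is its value after the old task; my importance hyperparameter plays the role of the paper's λ/2:

    L(\theta) = L_{\text{task}}(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta_i^{*}\right)^2

This is what the penalty method computes: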
    def penalty(self, model: nn.Module):
        loss = 0
        for n, p in model.named_parameters():
            # Fisher-weighted squared distance from the old parameter means
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return loss
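To make the symptom concrete: if the EWC object is created from the model in its current state, the penalty evaluates to exactly zero (a hypothetical check, using the EWC class shown further down):

    ewc = EWC(model, old_tasks)
    print(ewc.penalty(model).item())  # 0.0, since p equals self._means[n] for every parameter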
I have two problems: the new weights (p) and the old means (self._means[n]) are always the same, which results in a multiplication by 0 and completely cancels out the EWC penalty. I initialize self._means[n] and self._precision_matrices (the Fisher matrix) in the EWC class's __init__ method:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class EWC(object):
        def __init__(self, model: nn.Module, dataset: list, device='cpu'):
            self.model = model
            self.dataset = dataset
            self.device = device
            self._means = {}
            # Diagonal Fisher information, estimated on the old task's data
            self._precision_matrices = self._diag_fisher()
            # Snapshot of the parameters learned on the old task ("theta star")
            for n, p in self.model.named_parameters():
                self._means[n] = p.data.clone()
        def _diag_fisher(self):
            precision_matrices = {}
            # Initialise the diagonal Fisher estimate to zero for every parameter
            for n, p in self.model.named_parameters():
                params = p.clone().data.zero_()
                precision_matrices[n] = params

            self.model.eval()
            for input in self.dataset:
                input = input.to(self.device)
                self.model.zero_grad()
                output = self.model(input)
                # Use the model's own rounded prediction as the target
                label = torch.sigmoid(output).round()
                loss = F.binary_cross_entropy_with_logits(output, label)
                # loss = F.nll_loss(F.log_softmax(output, dim=1), label)
                loss.backward()

                # Accumulate squared gradients, averaged over the dataset
                for n, p in self.model.named_parameters():
                    precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)

            return precision_matrices
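One thing I am unsure about in _diag_fisher: rounding the sigmoid uses the model's hard prediction as the label, whereas the EWC paper defines the Fisher as an expectation under the model's own predictive distribution. A sketch of that variant (my change, not from the reference repo) would sample the label instead:

    # Sample the label from the model's predictive distribution instead of rounding;
    # this matches the expectation in the Fisher information definition more closely.
    label = torch.bernoulli(torch.sigmoid(output))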
Here is the actual training:
    # Train the model with EWC
    for epoch in tqdm(range(EPOCS)):
        # Get the loss for this epoch
        ls = ewc_train(model, opt, loss_func, dataloader[task],
                       EWC(model, old_tasks), importance, device)
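Note that EWC(model, old_tasks) is constructed inside the epoch loop, so self._means is cloned from the model's current weights on every epoch. A sketch of the alternative, assuming the intent is to snapshot the old-task parameters once before training the new task:

    # Build the EWC object once, right after training on the old task, so that
    # self._means keeps pointing at the old-task weights while the new task trains.
    ewc = EWC(model, old_tasks)
    for epoch in tqdm(range(EPOCS)):
        ls = ewc_train(model, opt, loss_func, dataloader[task], ewc, importance, device)

For reference, ewc_train itself: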
    def ewc_train(model: nn.Module, opt: torch.optim.Optimizer, loss_func: nn.Module,
                  data_loader: torch.utils.data.DataLoader, ewc: EWC,
                  importance: float, device):
        epoch_loss = 0
        for i, (inputs, labels) in enumerate(data_loader):
            inputs = inputs.to(device).long()
            labels = labels.to(device).float()

            opt.zero_grad()
            output = model(inputs)
            # Task loss plus the importance-weighted EWC penalty
            loss = loss_func(output.view(-1), labels) + importance * ewc.penalty(model)
            loss.backward()
            opt.step()

            epoch_loss += loss.item()
        # Return the loss accumulated over the whole epoch
        return epoch_loss
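When debugging this, it can help to log the two loss terms separately (a hypothetical variant of the loss line inside ewc_train):

    task_loss = loss_func(output.view(-1), labels)
    ewc_loss = importance * ewc.penalty(model)
    # The ewc term prints 0.0 here, which is exactly the problem described above
    print(f'task: {task_loss.item():.4f}  ewc: {ewc_loss.item():.4f}')
    loss = task_loss + ewc_loss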
Note: the loss function I am using is nn.BCEWithLogitsLoss(), and the optimizer is SGD(params=model.parameters(), lr=0.001).