EWC费希尔矩阵

时间:2019-12-21 18:54:44

标签: neural-network pytorch

我要重新实现此paper中概述的弹性重量合并(EWC)吗?作为参考,我还使用了repo(另一个实现)。

我的模型/想法非常简单。训练网络进行位操作AND(例如1 && 0 = 0),然后使用EWC,训练它使用OR(例如1 || 0 = 1)。我有三个输入:bit1,bit2和运算(0代表AND,1代表OR)和一个输出神经元-运算的输出。例如,如果我有0 1 0,则地面真相应为0。

但是,问题出在计算EWC损失时。

def penalty(self, model: nn.Module):
    loss = 0
    for n, p in model.named_parameters():
        _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
        loss += _loss.sum()
    return loss

我有两个问题:

  • 当前均值(p)和旧均值(self._means[n])始终相同,导致乘以0,从而完全抵消了EWC。
  • 由于我只有一个输出神经元,因此费舍尔矩阵的计算与回购略有不同。我写的那个似乎是错误的。有什么想法吗?

我通过EWC模型的初始化方法初始化self._means[n]self._precision_matrices(费舍尔矩阵):

class EWC(object):
def __init__(self, model: nn.Module, dataset: list, device='cpu'):

    self.model = model
    self.dataset = dataset
    self.device = device

    self._means = {}
    self._precision_matrices = self._diag_fisher()

    for n, p in self.model.named_parameters():
        self._means[n] = p.data.clone()

def _diag_fisher(self):
    precision_matrices = {}

    # Set it to zero
    for n, p in self.model.named_parameters():
        params = p.clone().data.zero_()
        precision_matrices[n] = params

    self.model.eval()

    for input in self.dataset:
        input = input.to(self.device)

        self.model.zero_grad()

        output = self.model(input)
        label = torch.sigmoid(output).round()
        loss = F.binary_cross_entropy_with_logits(output, label)
        # loss = F.nll_loss(F.log_softmax(output, dim=1), label)
        loss.backward()

        for n, p in self.model.named_parameters():
            precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)

    precision_matrices = {n: p for n, p in precision_matrices.items()}
    return precision_matrices

这是实际的培训:

# Train the model EWC
for epoch in tqdm(range(EPOCS)):

    # Get the loss
    ls = ewc_train(model, opt, loss_func, dataloader[task], EWC(model, old_tasks), importance, device)

def ewc_train(model: nn.Module, opt: torch.optim, loss_func:torch.nn, data_loader: torch.utils.data.DataLoader, ewc: EWC, importance: float, device):
    epoch_loss = 0

    for i, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.to(device).long()
        labels = labels.to(device).float()

        opt.zero_grad()

        output = model(inputs)
        loss = loss_func(output.view(-1), labels) + importance * ewc.penalty(model)
        loss.backward()
        opt.step()

        epoch_loss += loss.item()

    return loss

注意:我正在使用的损失函数为nn.BCEWithLogitsLoss(),优化函数为SGD(params=model.parameters(), lr=0.001)

0 个答案:

没有答案