在火炬中找到黑森州期间的问题

时间:2020-11-05 11:37:09

标签: gradient loss-function hessian

我正在编写ElasticWeightConsolidation方法,为此,我需要计算Fisher矩阵。据我了解,费舍尔矩阵只是神经网络权重的似然性。有很好的功能,例如torch.autograd.functional.hessian(func,inputs,create_graph = False,strict = False)

所以我想计算粗麻布(损耗,重量)。 损失= torch.nn.CrossEntropyLoss() 我还准备了网络的权重,以使其具有较长的一维张量,从而可以简单地像这样使用粗麻布的对角元素:

def flat_param(model_param = yann_lecun.parameters()):
  ans_data = []
  ans_data = torch.tensor(ans_data, requires_grad=True)
  ans_data = ans_data.to(device)
  for p in model_param:
    temp_data = p.data.flatten()
    ans_data = torch.cat((ans_data,temp_data))
  return ans_data

ans = flat_param(yann_lecun.parameters())

然后,我尝试这样做:hessian(loss, inputs = ans),但问题是损失也有目标,但我不想计算它们的粗麻布。任务是mnist分类,因此目标是整数0 ... 9 如果我将y_train添加到类似 hessian(loss,inputs = (ans,y_train_01)的参数中 令人沮丧的是“不能从整数中提取梯度”。我也尝试制作y_train_01.requires_grad = False ,但没有帮助。我知道损失也取决于y_train_01,但是在我的情况下,有什么方法可以确定目标是常量吗? 预先感谢!

1 个答案:

答案 0 :(得分:0)

您可以创建一个新的“包装器”功能,其中目标是固定的:

def original_model(features, targets):
    ...
    return ...

def feature_differentiable_model(features):
    fixed_targets = ...
    return original_model(features, fixed_targets)

然后调用:

hessian(feature_differentiable_model, features_vals)

由此产生的二阶偏导数将等价于位置 (features_vals, fixed_targets) 处的完整 Hessian 乘积的类似部分。