损失函数的Pytorch预计算部分

时间:2020-08-01 12:33:42

标签: python pytorch

我正在尝试使用pytorch来计算复杂函数,但仍然需要根据输入来改变输出的梯度。例如:

a=torch.tensor([1,2,3], dtype=torch.float32, requires_grad=True)
b=torch.tensor([3,2,1], dtype=torch.float32, requires_grad=True)
v=torch.tensor([0.2], dtype=torch.float32, requires_grad=True)

# here precalc represents some (fairly expensive) sequence of operations
precalc = a.dot(b)+a*b+a*a+b*b

def calc(precalc, v):
    z=torch.randn(3,1000)
    batch=v*precalc.matmul(z)
    return torch.relu(batch).mean()

因此,预计算张量是a和b的固定函数(但计算起来很昂贵)。

当我第一次以x=calc(precalc, v)然后以x.backward()的形式调用calc时,我得到了a,b和v的正确梯度。但是,如果我第二次调用calc,即{ 1}},然后是x2=calc(precalc, v),我得到了pytorch错误,该错误是我第二次向后浏览该图,应该保留该图。

理想情况下,我希望释放图形,因为precalc的计算成本很高,但内存较小,而calc本身更易于计算,但会占用大量内存。

有没有办法做到这一点?

谢谢。

0 个答案:

没有答案