Higher-order gradients in PyTorch

Date: 2018-05-14 03:46:11

Tags: python pytorch gradients autograd

I implemented the following Jacobian function in PyTorch. Unless I've made a mistake, it computes the Jacobian of any tensor with respect to an input of any dimension:

import torch
import torch.autograd as ag

def nd_range(stop, dims=None):
    # Recursively yield every index tuple into a tensor of shape `stop`,
    # e.g. nd_range([2, 2]) -> (0, 0), (0, 1), (1, 0), (1, 1)
    if dims is None:
        dims = len(stop)
    if not dims:
        yield ()
        return
    for outer in nd_range(stop, dims - 1):
        for inner in range(stop[dims - 1]):
            yield outer + (inner,)


def full_jacobian(f, wrt):
    f_shape = list(f.size())
    wrt_shape = list(wrt.size())
    fs = []

    # Differentiate each scalar element of f separately; create_graph=True
    # keeps the gradients themselves differentiable, so that higher-order
    # derivatives remain possible
    for f_ind in nd_range(f_shape):
        grad = ag.grad(f[tuple(f_ind)], wrt, retain_graph=True, create_graph=True)[0]
        for i in range(len(f_shape)):
            grad = grad.unsqueeze(0)
        fs.append(grad)

    fj = torch.cat(fs, dim=0)
    fj = fj.view(f_shape + wrt_shape)
    return fj
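
As a quick sanity check (an illustrative example: for y = x ** 2 the Jacobian should be diag(2x)):

x = torch.arange(1.0, 4.0, requires_grad=True)   # tensor([1., 2., 3.])
y = x ** 2
print(full_jacobian(y, x))   # expected: diagonal matrix with 2., 4., 6. on the diagonal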

On top of this, I also tried to implement a recursive function to compute nth-order derivatives:

def nth_derivative(f, wrt, n):
    # The nth derivative is the Jacobian of the (n-1)th derivative
    if n == 1:
        return full_jacobian(f, wrt)
    else:
        deriv = nth_derivative(f, wrt, n - 1)
        return full_jacobian(deriv, wrt)

I ran a simple test:

# s is a 1-D tensor with requires_grad=True (its definition is not shown above)
op = torch.ger(s, s)
deep_deriv = nth_derivative(op, s, 5)

Unfortunately, this successfully gets me the Hessian... but none of the higher-order derivatives. I'm aware that many of the higher-order derivatives should be 0, but I'd prefer that PyTorch compute them analytically.
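
Presumably what breaks (a minimal sketch of the failure mode as I read it; the constant below stands in for a derivative that no longer depends on s):

s = torch.ones(3, requires_grad=True)   # for concreteness
c = torch.tensor(2.0)                   # stands in for a derivative that is constant in s
try:
    ag.grad(c, s)
except RuntimeError as e:
    print(e)   # e.g. "element 0 of tensors does not require grad and does not have a grad_fn"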

One fix would be to change the gradient computation to:

try:
    grad = ag.grad(f[tuple(f_ind)], wrt, retain_graph=True, create_graph=True)[0]
except RuntimeError:
    # the derivative is identically zero, so grad cannot trace it back to wrt
    grad = torch.zeros_like(wrt)

Is this the right, accepted way to do it? Or is there a better option? Or do I have the premise of my question entirely wrong?

2 answers:

Answer 0 (score: 6)

You can just iteratively apply the grad function:

import torch
from torch.autograd import grad

def nth_derivative(f, wrt, n):
    for i in range(n):
        # create_graph=True keeps the gradient differentiable for the next pass
        grads = grad(f, wrt, create_graph=True)[0]
        # Reduce to a scalar for the next grad call; for an elementwise f each
        # entry of grads depends only on the matching entry of wrt, so summing
        # does not mix derivatives across entries
        f = grads.sum()

    return grads

x = torch.arange(4., requires_grad=True).reshape(2, 2)  # 4. so the dtype is float
loss = (x ** 4).sum()

print(nth_derivative(f=loss, wrt=x, n=3))

Output:

tensor([[  0.,  24.],
        [ 48.,  72.]])
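
As a quick check, the third derivative of x ** 4 is 24 * x, which matches the output above at x = 0, 1, 2, 3:

x = torch.arange(4.).reshape(2, 2)
print(24 * x)   # tensor([[ 0., 24.], [48., 72.]])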

Answer 1 (score: 0)

For second-order derivatives, you can use PyTorch's torch.autograd.functional.hessian() function.

For higher-order derivatives, you can repeatedly call jacobian while maintaining the computation graph:

> create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products.
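
A minimal usage sketch (my own example, assuming PyTorch 1.5+ where torch.autograd.functional is available):

import torch
from torch.autograd.functional import hessian

def f(x):
    return (x ** 4).sum()

x = torch.arange(4.).reshape(2, 2)
H = hessian(f, x)        # shape (2, 2, 2, 2): H[i, j, k, l] = d2f / dx[i, j] dx[k, l]
print(H.reshape(4, 4))   # diagonal matrix with entries 12 * x ** 2

Passing create_graph=True to hessian (or jacobian) keeps the result differentiable, which is what enables the repeated calls mentioned above.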