如何使用autograd查找MIN / MAX点

时间:2019-04-10 04:20:36

标签: python pytorch autograd

假设我们有一个简单的函数y = sin(x ** 2),如何使用autograd查找所有一阶导数为0的X:s?

1 个答案:

答案 0 :(得分:0)

下面的代码可以找到一阶导数为零的点。但是,根据随机初始化,只会找到一个点。如果要查找所有点,则可以尝试在某个所需的网格上迭代许多随机初始化。

import torch 
import numpy as np
# initialization
x = torch.tensor(np.random.rand(1)).requires_grad_(True)

while (x.grad is None or torch.abs(x.grad)>0.01):
    if (x.grad is not None):
        # zero grads
        x.grad.data.zero_()
    # compute fn
    y = torch.sin(x**2)
    # compute grads
    y.backward()
    # move in direction of / opposite to grads
    x.data = x.data - 0.01*x.grad.data
    # use below line to move uphill 
    # x.data = x.data + 0.01*x.grad.data

print(x)
print(y)
print(x.grad)

另请参见how to apply gradients manually in pytorch