使用Tensorflow / PyTorch加速自定义功能的最小化

时间:2019-07-06 22:54:14

标签: python tensorflow optimization pytorch minimization

我做了很多仿真,为此,我经常需要最小化复杂的用户定义函数,为此,我通常使用numpyscipy.optimize.minimize()。但是,这样做的问题是我需要明确地写下一个梯度函数,有时很难找到它/不可能。对于大维向量,由scipy计算的数值导数非常昂贵。

因此,我正在尝试切换到TensorflowPyTorch来利用它们的自动区分功能,并能够自由地利用GPU。让我举一个明确的函数示例,该函数的导数编写起来有些复杂(将需要很多链规则),因此对于TensorflowPyTorch来说似乎已经成熟了-计算之间的二面角在3d空间中由四个点形成的两个三角形:

def dihedralAngle(xyz):
## calculate dihedral angle between 4 nodes

    p1, p2, p3, p4 = 0, 1, 2, 3

    ## get unit normal vectors
    N1 = np.cross(xyz[p1]-xyz[p3] , xyz[p2]-xyz[p3])
    N2 = - np.cross(xyz[p1]-xyz[p4] , xyz[p2]-xyz[p4])
    n1, n2 = N1 / np.linalg.norm(N1), N2 / np.linalg.norm(N2) 

    angle = np.arccos(np.dot(n1, n2))

    return angle

xyz1 = np.array([[0.2       , 0.        , 0.        ],
       [0.198358  , 0.02557543, 0.        ],
       [0.19345897, 0.05073092, 0.        ],
       [0.18538335, 0.0750534 , 0.        ]]) # or any (4,3) ndarray

print(dihedralAngle(xyz1)) >> 3.141

我可以很容易地使用scipy.optimize.minimize()最小化它,我应该得到0。对于这么小的函数,我真的不需要渐变(显式或数值)。但是,如果我希望遍历许多节点并最小化某些依赖于所有二面角的函数,那么开销会更高吗?

那么我的问题-

  1. 如何使用TensorFlowPyTorch来实现此最小化问题?既适用于单个二面角,又适用于此类角的列表(即,我们需要考虑遍历列表)。
  2. 此外,如果需要,我是否可以使用自动微分来获得渐变并重新插入scipy.optimize.minimize()?例如,scipy.optimize.minimize()可以轻松地限制和约束,这在Tensorflow或PyToch优化模块中没有注意到。

2 个答案:

答案 0 :(得分:1)

这是一个解决方案,它使用 Torch 自动计算梯度,然后使用我编写的库 autograd-minimize 使用 scipy 的最小化器。优于 SGD 的优点是估计精度更高(使用二阶方法)。它可能相当于使用来自 Torch 的 LBFGS:

import numpy as np
import torch
from autograd_minimize import minimize


def dihedralAngle(xyz):
## calculate dihedral angle between 4 nodes

    p1, p2, p3, p4 = 0, 1, 2, 3

    ## get unit normal vectors
    N1 = np.cross(xyz[p1]-xyz[p3] , xyz[p2]-xyz[p3])
    N2 = - np.cross(xyz[p1]-xyz[p4] , xyz[p2]-xyz[p4])
    n1, n2 = N1 / np.linalg.norm(N1), N2 / np.linalg.norm(N2) 

    angle = np.arccos(np.dot(n1, n2))

    return angle
def compute_angle(p1, p2):
    # inner_product = torch.dot(p1, p2)
    inner_product = (p1*p2).sum(-1)
    p1_norm = torch.linalg.norm(p1, axis=-1)
    p2_norm = torch.linalg.norm(p2, axis=-1)
    cos = inner_product / (p1_norm * p2_norm)
    cos = torch.clamp(cos, -0.99999, 0.99999)
    angle = torch.acos(cos)
    return angle

def compute_dihedral(v1,v2,v3,v4):
    ab = v1 - v2
    cb = v3 - v2
    db = v4 - v3
    u = torch.cross(ab, cb)
    v = torch.cross(db, cb)
    w = torch.cross(u, v)
    angle = compute_angle(u, v)
    angle = torch.where(compute_angle(cb, w) > 1, -angle, angle)

    return angle

def loss_func(v1,v2,v3,v4):
    return ((compute_dihedral(v1,v2,v3,v4)+2)**2).mean()


x0=[np.array([-17.0490,   5.9270,  21.5340]),
    np.array([-0.1608,  0.0600, -0.0371]),
    np.array([-0.2000,  0.0007, -0.0927]),
    np.array([-0.1423,  0.0197, -0.0727])]

res = minimize(loss_func, x0, backend='torch')

print(compute_dihedral(*[torch.tensor(v) for v in res.x])) 

答案 1 :(得分:0)

我正在做同样的事情。 这是我得到的。

def compute_angle(p1, p2):
    # inner_product = torch.dot(p1, p2)
    inner_product = (p1*p2).sum(-1)
    p1_norm = torch.linalg.norm(p1, axis=-1)
    p2_norm = torch.linalg.norm(p2, axis=-1)
    cos = inner_product / (p1_norm * p2_norm)
    cos = torch.clamp(cos, -0.99999, 0.99999)
    angle = torch.acos(cos)
    return angle
def compute_dihedral(v1,v2,v3,v4):
    ab = v1 - v2
    cb = v3 - v2
    db = v4 - v3
    u = torch.cross(ab, cb)
    v = torch.cross(db, cb)
    w = torch.cross(u, v)
    angle = compute_angle(u, v)
    # angle = torch.where(compute_angle(cb, w) > 0.001, -angle, angle)
    angle = torch.where(compute_angle(cb, w) > 1, -angle, angle)
#     try:
#         if compute_angle(cb, w) > 0.001:
#             angle = -angle
#     except ZeroDivisionError:
#         # dihedral=pi
#         pass
    return angle

v1 = torch.tensor([-17.0490,   5.9270,  21.5340], requires_grad=True)
v2 = torch.tensor([-0.1608,  0.0600, -0.0371], requires_grad=True)
v3 = torch.tensor([-0.2000,  0.0007, -0.0927], requires_grad=True)
v4 = torch.tensor([-0.1423,  0.0197, -0.0727], requires_grad=True)

dihedral = compute_dihedral(v1,v2,v3,v4)
target_dihedral = -2

print(dihedral)   # should print -2.6387


for i in range(100):
    dihedral = compute_dihedral(v1,v2,v3,v4)
    loss = (dihedral - target_dihedral)**2
    loss.backward()
    learning_rate = 0.001
    with torch.no_grad():
        v1 -= learning_rate * v1.grad
        v2 -= learning_rate * v2.grad
        v3 -= learning_rate * v3.grad
        v4 -= learning_rate * v4.grad

        # Manually zero the gradients after updating weights
        v1.grad = None
        v2.grad = None
        v3.grad = None
        v4.grad = None
print(compute_dihedral(v1,v2,v3,v4))   # should print -2