这个问题历史悠久:
我有一个复杂的函数,我想计算其梯度,而且速度太快了。我决定使用autograd
,效果很好。但是,我需要加快速度,因此我决定在autograd
中使用jax
函数,该函数可以使用GPU加速。
但是,jax
在返回与autograd
相同的答案的同时,却无法加快操作速度,甚至偶而使内核崩溃。通过jax
论坛,我了解到问题是我的函数未向量化,即它使用for
循环而不是向量运算来遍历数组(其原因是我最初有为cython
编写了函数,该函数需要for循环)。因此,我对函数(及其辅助包装函数)进行了矢量化处理,并通过确保函数返回相同的值来验证过程中没有发生错误。
但是,使用此矢量化功能,autograd
和jax
都返回零梯度。我猜想在包装函数中可能有一些小错误或不一致,但是我无法弄清楚是什么。
我在Google colab上发布了一个完整的工作示例,其中包含函数的矢量化版本和非矢量化版本: https://colab.research.google.com/drive/1VCPiRTLfxOflooDtTlqHI4VzYQkXzaUn
,并且还在jax
论坛上发布了一个问题,但没有成功:
https://github.com/google/jax/issues/1407
有人可以帮忙吗?