Python:向量化功能后,Autograd返回零梯度

时间:2019-10-18 16:04:33

标签: python vectorization autograd

这个问题历史悠久:

我有一个复杂的函数,我想计算其梯度,而且速度太快了。我决定使用autograd,效果很好。但是,我需要加快速度,因此我决定在autograd中使用jax函数,该函数可以使用GPU加速。

但是,jax在返回与autograd相同的答案的同时,却无法加快操作速度,甚至偶而使内核崩溃。通过jax论坛,我了解到问题是我的函数未向量化,即它使用for循环而不是向量运算来遍历数组(其原因是我最初有为cython编写了函数,该函数需要for循环)。因此,我对函数(及其辅助包装函数)进行了矢量化处理,并通过确保函数返回相同的值来验证过程中没有发生错误。

但是,使用此矢量化功能,autogradjax都返回零梯度。我猜想在包装函数中可能有一些小错误或不一致,但是我无法弄清楚是什么。

我在Google colab上发布了一个完整的工作示例,其中包含函数的矢量化版本和非矢量化版本: https://colab.research.google.com/drive/1VCPiRTLfxOflooDtTlqHI4VzYQkXzaUn

,并且还在jax论坛上发布了一个问题,但没有成功: https://github.com/google/jax/issues/1407

有人可以帮忙吗?

0 个答案:

没有答案