Python中的多精度浮点矢量化计算

时间:2019-03-23 22:22:00

标签: python gmpy

我正在尝试对Blahut-Arimoto算法进行编码,以优化通道X-> Y的速率失真。该算法使用Lagrange乘数beta,范围从很小到接近无穷大。该解决方案采用beta的指数,因此,如果beta非常大,将超过numpy中浮点数的精度。

我了解numpy数组与多个精度库不兼容,因此我需要使用类似于gmpy2的库/包装器。但是,这种方法效率低下,因此需要检测何时有必要,而不是使用多重精度作为默认值。

我已经能够使用np.errstate来检测无效值。下一步是利用gmpy2来计算指数。参数是numpy数组。有没有办法并行计算?

此外,我尝试使用gmpy2,但是我没有得到一个明智的答案。我只是不确定它是如何工作的,还没有找到说明问题的文档。

我没有设置特定的解决方案,只是选择了gmpy2,因为它似乎是C库mpfr的最新包装器。

def blahut_step(p_x, betaD):
    """
    Rate Distortion optimisation for communicating signal X -> Y
    using Lagrange Multiplier beta and distortion matrix D
    multi precision floating arithmetic may be required 
        p_x_next = p_k * exp[betaD].normalised

    BigFloat,gmpy2 - python wrapper for gnu mpfr library

    Params:
        p_x:    current estimate for distribution
        betaD:  beta * distortion matrix (nX, nY)

    Returns:
        next estimate of p_x
    """

    with np.errstate(invalid='raise'):
        try:
            result = np.multiply(p_x, np.exp(betaD))
            return result / np.sum(result, axis=1)[:, None] # normalise
        except FloatingPointError as err:
            # use multiple precision library such as gmpy2
            print("multiple precision required: {0}".format(err))
            exit(0)

感谢您的帮助,即使只是使用gmpy2将我指向具有多个精度的矩阵计算的方向。

0 个答案:

没有答案