在Python中更快地卷积概率密度函数

时间:2015-03-06 14:45:44

标签: python numpy vectorization convolution probability-density

假设需要计算一般数量的离散概率密度函数的卷积。对于下面的示例,有四个分布采用具有指定概率的值0,1,2:

import numpy as np
pdfs = np.array([[0.6,0.3,0.1],[0.5,0.4,0.1],[0.3,0.7,0.0],[1.0,0.0,0.0]])

卷积可以这样找到:

pdf = pdfs[0]        
for i in range(1,pdfs.shape[0]):
    pdf = np.convolve(pdfs[i], pdf)

然后由

给出看到0,1,...,8的概率
array([ 0.09 ,  0.327,  0.342,  0.182,  0.052,  0.007,  0.   ,  0.   ,  0.   ])

这部分是我的代码中的瓶颈,似乎必须有一些东西可用于矢量化这个操作。有没有人建议让它更快?

或者,您可以使用

的解决方案
pdf1 = np.array([[0.6,0.3,0.1],[0.5,0.4,0.1]])
pdf2 = np.array([[0.3,0.7,0.0],[1.0,0.0,0.0]])
convolve(pd1,pd2) 

获得成对卷积

 array([[ 0.18,  0.51,  0.24,  0.07,  0.  ], 
        [ 0.5,  0.4,  0.1,  0. ,  0. ]])

也会有很大的帮助。

1 个答案:

答案 0 :(得分:14)

您可以使用快速傅里叶变换(FFT)有效地计算所有PDF的卷积:关键的事实是FFT of the convolution是各个概率密度函数的FFT的乘积。因此,转换每个PDF,将转换后的PDF相乘,然后执行逆变换。您需要将每个输入PDF用零填充到适当的长度,以避免环绕效应。

这应该是相当有效的:如果你有m个PDF,每个包含n条目,那么使用这种方法计算卷积的时间应该增长为(m^2)n log(mn)。时间由FFT控制,我们有效地计算m + 1个独立的FFT(m正向变换和一个逆变换),每个长度不大于{{1}的数组}。但与往常一样,如果你想要真正的时间,你应该描述。

这里有一些代码:

mn

将此应用于您的示例,这是我得到的:

import numpy.fft

def convolve_many(arrays):
    """
    Convolve a list of 1d float arrays together, using FFTs.
    The arrays need not have the same length, but each array should
    have length at least 1.

    """
    result_length = 1 + sum((len(array) - 1) for array in arrays)

    # Copy each array into a 2d array of the appropriate shape.
    rows = numpy.zeros((len(arrays), result_length))
    for i, array in enumerate(arrays):
        rows[i, :len(array)] = array

    # Transform, take the product, and do the inverse transform
    # to get the convolution.
    fft_of_rows = numpy.fft.fft(rows)
    fft_of_convolution = fft_of_rows.prod(axis=0)
    convolution = numpy.fft.ifft(fft_of_convolution)

    # Assuming real inputs, the imaginary part of the output can
    # be ignored.
    return convolution.real

这是基本想法。如果你想调整它,你也可以看一下numpy.fft.rfft(和它的逆,numpy.fft.irfft),它利用输入是真实的,以产生更紧凑的变换数组。您也可以通过用{0}填充>>> convolve_many([[0.6, 0.3, 0.1], [0.5, 0.4, 0.1], [0.3, 0.7], [1.0]]) array([ 0.09 , 0.327, 0.342, 0.182, 0.052, 0.007]) 数组来获得一些速度,以便总列数最适合执行FFT。 "最佳"的定义这取决于FFT的实现,但例如,2的幂是良好的目标。最后,如果所有输入数组具有相同的长度,则在创建rows时可以进行一些明显的简化。但是我会把这些潜在的改进留给你。