加速numpy kronecker产品

时间:2011-08-25 16:23:05

标签: python numpy

我正在开发我的第一个大型python项目。我有一个函数,其中包含以下代码:

            # EXPAND THE EXPECTED VALUE TO APPLY TO ALL STATES,
            # THEN UPDATE fullFnMat
            EV_subset_expand = np.kron(EV_subset, np.ones((nrows, 1)))
            fullFnMat[key] = staticMat[key] + EV_subset_expand                

在我的代码分析器中,看起来这个kronecker产品实际上占用了大量的时间。

Function                                                                                        was called by...
                                                                                                    ncalls  tottime  cumtime
/home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/BellmanEquation.py:17(bellmanFn)    <-      19   37.681   38.768  /home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/dpclient.py:467(solveTheModel)
{numpy.core.multiarray.concatenate}                                                             <-     342   27.319   27.319  /usr/lib/pymodules/python2.7/numpy/lib/shape_base.py:665(kron)
/home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/dpclient.py:467(solveTheModel)      <-       1   11.041   91.781  <string>:1(<module>)
{method 'argsort' of 'numpy.ndarray' objects}                                                   <-      19    7.692    7.692  /usr/lib/pymodules/python2.7/numpy/core/fromnumeric.py:597(argsort)
/usr/lib/pymodules/python2.7/numpy/core/numeric.py:789(outer)                                   <-     171    2.526    2.527  /usr/lib/pymodules/python2.7/numpy/lib/shape_base.py:665(kron)
{method 'max' of 'numpy.ndarray' objects}                                                       <-     209    2.034    2.034  /home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/dpclient.py:391(getValPolMatrices)

有没有办法在Numpy中获得更快的kronecker产品?它似乎不应该花费很长时间。

3 个答案:

答案 0 :(得分:7)

您当然可以查看np.kron的来源。它可以在numpy/lib/shape_base.py中找到,您可以看到是否有可以进行的改进或可以使其更有效的简化。或者,您可以使用Cython或其他一些低级语言绑定来编写自己的语言,以尝试获得更好的性能。

或者@matt建议以下内容可能本身更快:

import numpy as np
nrows = 10
a = np.arange(100).reshape(10,10)
b = np.tile(a,nrows).reshape(nrows*a.shape[0],-1) # equiv to np.kron(a,np.ones((nrows,1)))

或:

b = np.repeat(a,nrows*np.ones(a.shape[0],np.int),axis=0)

时序:

In [80]: %timeit np.tile(a,nrows).reshape(nrows*a.shape[0],-1)
10000 loops, best of 3: 25.5 us per loop

In [81]: %timeit np.kron(a,np.ones((nrows,1)))
10000 loops, best of 3: 117 us per loop

In [91]: %timeit np.repeat(a,nrows*np.ones(a.shape[0],np.int),0)
100000 loops, best of 3: 12.8 us per loop

在上面的示例中使用np.repeat表示大小的数组,可以提供非常好的10倍加速,这不是太糟糕。

答案 1 :(得分:2)

也许np.kron()正在分配内存然后你就把它扔掉了。请尝试使用np.tile()。我不知道是否会分配更多内存或在封面下播放索引技巧。如果您只将EV_subset乘以1,则实际上不需要调用np.kron()

答案 2 :(得分:1)

以下内容可能会有所帮助(在一般情况下,其中一个阵列不是&#39;那些&#39;)。 例子是两个形状(a,b,c)和(d,e,f)的阵列A,B; 根据需要进行概括。

通过单个&#39;乘法&#39;完成它。 op和快速重塑。

kprod = A[:,newaxis,:,newaxis,:,newaxis] * B[newaxis,:, newaxis,:, newaxis,:]
#
# kprod.shape = (a,d,b,e,c,f) now; is full outer product with desired arrangement
# in memory.
kprod.shape = (a*d,b*e,c*f)  # reshape 'in place' 

(也许这是kron(B,A)而不是kron(A,B);如果需要可以反转A&amp; B