Python优化:如何加快矩阵逆运算?

时间:2019-01-30 14:38:07

标签: python numpy scipy

我的代码包含一个带有大量迭代的for循环。在循环中,我需要矩阵乘法和逆矩阵(通常是大小为12 x 12的矩阵)。我的循环需要运行120,000次,目前我的速度为14s,与MATLAB(1s)和FORTRAN(0.4s)相比,它的速度相对较高。以下是我要优化的功能:

def fixed_simulator(ref, xg, yg, dt, ndiv, ijk, nst, smx, skx, cdx, smy, sky, cdy):

    gamma = 0.5
    beta = 1.0/6.0

    knx = skx + (gamma/beta/dt)*cdx + (1.0/beta/np.power(dt,2))*smx

    dx1 = np.ones((nst,1), dtype=float)*0.0
    vx1 = np.ones((nst,1), dtype=float)*0.0
    px1 = np.diag(-1.0*smx).reshape(nst,1)*xg[0]
    ax1 = np.matmul(linalg.inv(smx), px1 - np.matmul(cdx, vx1) - np.matmul(skx, dx1))

    # I = np.ones((nst,1), dtype=float)

    dx2 = np.zeros((nst,1), dtype=float)
    vx2 = np.zeros((nst,1), dtype=float)
    px2 = np.zeros((nst,1), dtype=float)
    ax2 = np.zeros((nst,1), dtype=float)

    na1x = (1.0/beta/np.power(dt,2))*smx + (gamma/beta/dt)*cdx
    na2x = (1.0/beta/dt)*smx + (gamma/beta - 1.0)*cdx
    na3x = (1.0/2.0/beta - 1.0)*smx + (gamma*dt/2.0/beta - dt)*cdx

    print(len(xg))

# -----> Below is the loop that's taking long time.  

    for i in range(1,len(xg)):

        px2 = np.diag(-1.0*smx).reshape(nst,1)*xg[i]

        pcx1 = px2 + np.matmul(na1x, dx1) + np.matmul(na2x, vx1) + np.matmul(na3x, ax1)

        dx2 =  np.matmul(np.linalg.inv(smx), pcx1)

        vx2 = (gamma/beta/dt)*(dx2 - dx1) + (1.0 - gamma/beta)*vx1 + dt*(1.0 - gamma/2.0/beta)*ax1

        ax2 = np.matmul(np.linalg.inv(smx), px2 - np.matmul(cdx, vx2) - np.matmul(skx, dx2))

        dx1, vx1, px1, ax1 = dx2, vx2, px2, ax2 

大部分时间似乎都在计算逆和乘法部分。

系统中的Numpy配置:

blas_mkl_info:
  NOT AVAILABLE
blis_info:
  NOT AVAILABLE
openblas_info:
    library_dirs = ['C:\\projects\\numpy-wheels\\numpy\\build\\openblas']
    libraries = ['openblas']
    language = f77
    define_macros = [('HAVE_CBLAS', None)]
blas_opt_info:
    library_dirs = ['C:\\projects\\numpy-wheels\\numpy\\build\\openblas']
    libraries = ['openblas']
    language = f77
    define_macros = [('HAVE_CBLAS', None)]
lapack_mkl_info:
  NOT AVAILABLE
openblas_lapack_info:
    library_dirs = ['C:\\projects\\numpy-wheels\\numpy\\build\\openblas']
    libraries = ['openblas']
    language = f77
    define_macros = [('HAVE_CBLAS', None)]
lapack_opt_info:
    library_dirs = ['C:\\projects\\numpy-wheels\\numpy\\build\\openblas']
    libraries = ['openblas']
    language = f77
    define_macros = [('HAVE_CBLAS', None)]

cProfile结果

         2157895 function calls in 2.519 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    1.474    1.474    2.519    2.519 C:\Users\Naseef\OneDrive\04AllPhDPrograms\mhps\mhps\fixed.py:154(fixed_simulator)
   839163    0.556    0.000    0.556    0.000 {built-in method numpy.core.multiarray.matmul}
   119881    0.105    0.000    0.439    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\lib\twodim_base.py:197(diag)
   119881    0.083    0.000    0.256    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\core\fromnumeric.py:1294(diagonal)
   239762    0.049    0.000    0.107    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\core\numeric.py:504(asanyarray)
   119881    0.103    0.000    0.103    0.000 {method 'diagonal' of 'numpy.ndarray' objects}
   239763    0.058    0.000    0.058    0.000 {built-in method numpy.core.multiarray.array}
   119881    0.049    0.000    0.049    0.000 {method 'reshape' of 'numpy.ndarray' objects}
   119881    0.022    0.000    0.022    0.000 {built-in method builtins.isinstance}
   239764    0.019    0.000    0.019    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:468(inv)
        2    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\core\numeric.py:156(ones)
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:141(_commonType)
        2    0.000    0.000    0.000    0.000 {built-in method numpy.core.multiarray.empty}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        2    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\progressbar\utils.py:28(write)
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:108(_makearray)
        2    0.000    0.000    0.000    0.000 {built-in method numpy.core.multiarray.copyto}
        4    0.000    0.000    0.000    0.000 {built-in method numpy.core.multiarray.zeros}
        2    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\progressbar\bar.py:547(update)
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:126(_realType)
        1    0.000    0.000    0.000    0.000 {method 'astype' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\core\numeric.py:433(asarray)
        2    0.000    0.000    0.000    0.000 {method 'write' of '_io.StringIO' objects}
        2    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:113(isComplexType)
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:103(get_linalg_error_extobj)
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:200(_assertRankAtLeast2)
        1    0.000    0.000    0.000    0.000 C:\Users\Naseef\AppData\Local\Programs\Python\Python36-32\lib\site-packages\numpy\linalg\linalg.py:211(_assertNdSquareness)
        2    0.000    0.000    0.000    0.000 {built-in method time.perf_counter}
        3    0.000    0.000    0.000    0.000 {built-in method builtins.issubclass}
        1    0.000    0.000    0.000    0.000 {method '__array_prepare__' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}

1 个答案:

答案 0 :(得分:1)

在尝试加速numpy代码之前,通常需要执行两个步骤。

  1. 分析代码以找出花费最多的时间
  2. 构建几个测试用例,调用代码以验证优化后的版本仍然正常运行

测试用例应该简单,快速地运行,但是仍然可以反映真实数据。 您提供了概要分析,但是没有测试用例,因此接下来的工作将是一些未经测试的猜测工作。 查看运行时间最长的测试用例,很明显,运行时间来自循环,并且大部分来自矩阵运算。

119881 0.049 0.000 0.049 0.000 {'numpy.ndarray'对象的方法'重塑'} 239763 0.058 0.000 0.058 0.000 {内置方法numpy.core.multiarray.array} 119881 0.103 0.000 0.103 0.000 {'numpy.ndarray'对象的方法'对角线'} 239762 0.049 0.000 0.107 0.000 ... \ core \ numeric.py:504(asanyarray) 119881 0.083 0.000 0.256 0.000 ... \ core \ fromnumeric.py:1294(对角线) 119881 0.105 0.000 0.439 0.000 ... \ lib \ twodim_base.py:197(diag) 839163 0.556 0.000 0.556 0.000 {内置方法numpy.core.multiarray.matmul}

第一个奇怪的是,np.linalg.inv(smx)在慢速操作中没有出现。 我认为您误解了评论者的建议,并将其完全移出了主循环。 它仍应处于主循环中,但只能调用一次。

for i in range(1,len(xg)):
    ....
    smxinv = np.linalg.inv(smx) ## Calculate inverse once per loop
    dx2 =  np.matmul(smxinv, pcx1)
...
ax2 = np.matmul(smxinv, px2 - np.matmul(cdx, vx2) - np.matmul(skx, dx2))
...

最慢的操作是matmul。 这并不奇怪-在主循环中被调用了七次。 每个调用似乎都有唯一的参数,因此我看不出有任何简单的方法可以加快速度。 接下来是diagdiagonal。 这些将创建一个对角线数组,其中大多数条目为零,因此将创建内容移动到循环之外,并且仅更新非零条目应该可以提高速度。

##  Pre allocate px2 array (may not have a large effect)
px2 = np.diag(1).reshape(nst,1)
px2i = where(px2) ## Setup index of non-zero entries

for i in range(1,len(xg)):
    px2[px2i] = -smx*xg[i]  ## This should be equivalent
    ...

这也消除了重塑的要求。 您还可以预先计算一些常量,并避免每个循环进行一些计算, 但这可能不会对整体运行时间产生很大影响。

每个步骤都需要针对一个测试用例进行操作,以确保它们不会更改功能的行为,然后进行概要分析以查看提供了多少(如果有)改进。 我希望您能在4到5秒钟之内得到它,但Python无法与编译语言匹敌。