闪电战代码产生不同的输出

时间:2013-04-26 07:21:09

标签: python numpy scipy

我想使用weave.blitz来改善以下numpy代码的性能:

def fastIteration(self):
    g = self.grid
    nx,ny = g.ux.shape

    uxold = g.old_ux
    ux = g.ux
    ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])

    g.setBC()
    g.old_ux = ux.copy()

在此代码中,g是计算网格。其中包括两个不同的字段ux和uxold。旧版仅用于临时存储变量。在完整的代码中,大约95%的运行时花费在fastIteration方法中,因此即使简单的性能提升也会减少执行此代码所花费的时间。

numpy方法的输出看起来好像:

numpy result

由于此代码是我的瓶颈,我想通过使用编织闪电来提高速度。此方法如下所示:

def blitzIteration(self):
    ### does not work correct so far
    g = self.grid
    nx,ny = g.ux.shape

    uxold = g.old_ux
    ux = g.ux
    expr = "ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])"
    weave.blitz(expr, check_size=0)
    g.setBC()
    g.old_ux = ux.copy()

然而,这不会产生正确的输出: output for blitz code

1 个答案:

答案 0 :(得分:2)

它看起来像weave.blitz中的一个错误(转载,归档和fixed。有关于那里的实际错误的更多信息)。

我认为编写0:而不是较短的:来获得一个完整的切片是很奇怪的,所以我替换了所有这些切片并且vo,它起作用了。

我真的不知道错误的位置,但weave.blitz生成的expr_code略有不同:

  • 使用0:

    ipdb> expr_code
    'ux_blitz_buggy(blitz::Range(0,_end),blitz::Range(1,Nux_blitz_buggy(1)-1-1))=uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(blitz::Range(0,_end),blitz::Range(2,_end))-2*uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+uxold(blitz::Range(0,_end),blitz::Range(0,Nuxold(1)-2-1)));\n'
    
  • 使用:

    ipdb> expr_code
    'ux_blitz_not_buggy(_all,blitz::Range(1,Nux_blitz_not_buggy(1)-1-1))=uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(_all,blitz::Range(2,_end))-2*uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+uxold(_all,blitz::Range(0,Nuxold(1)-2-1)));\n'
    

因此,blitz::Range(0,_end)变为_all并且他们的行为方式不同。

为方便起见,这里有一个完整的脚本可以重现问题,只有在问题出现时才会成功。

import numpy as np
from scipy.weave import blitz


def test_blitz_bug(N=4):
    ReI = 1.2
    ux_blitz_buggy, ux_blitz_not_buggy, ux_np = np.zeros((N, N)), np.zeros((N, N)), np.zeros((N, N))
    uxold = np.random.randn(N, N)
    ux_np[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])
    expr_buggy = 'ux_blitz_buggy[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])'
    expr_not_buggy = 'ux_blitz_not_buggy[:,1:-1] = uxold[:,1:-1] + ReI* (uxold[:,2:] - 2*uxold[:,1:-1] + uxold[:,0:-2])'
    blitz(expr_buggy)
    blitz(expr_not_buggy)
    assert not np.allclose(ux_blitz_buggy, ux_np)
    assert np.allclose(ux_blitz_not_buggy, ux_np)

if __name__ == '__main__':
    test_blitz_bug()