我想使用weave.blitz来改善以下numpy代码的性能:
def fastIteration(self):
g = self.grid
nx,ny = g.ux.shape
uxold = g.old_ux
ux = g.ux
ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])
g.setBC()
g.old_ux = ux.copy()
在此代码中,g是计算网格。其中包括两个不同的字段ux和uxold。旧版仅用于临时存储变量。在完整的代码中,大约95%的运行时花费在fastIteration方法中,因此即使简单的性能提升也会减少执行此代码所花费的时间。
numpy方法的输出看起来好像:
由于此代码是我的瓶颈,我想通过使用编织闪电来提高速度。此方法如下所示:
def blitzIteration(self):
### does not work correct so far
g = self.grid
nx,ny = g.ux.shape
uxold = g.old_ux
ux = g.ux
expr = "ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])"
weave.blitz(expr, check_size=0)
g.setBC()
g.old_ux = ux.copy()
然而,这不会产生正确的输出:
答案 0 :(得分:2)
它看起来像weave.blitz
中的一个错误(转载,归档和fixed。有关于那里的实际错误的更多信息)。
我认为编写0:
而不是较短的:
来获得一个完整的切片是很奇怪的,所以我替换了所有这些切片并且vo,它起作用了。
我真的不知道错误的位置,但weave.blitz
生成的expr_code
略有不同:
使用0:
ipdb> expr_code
'ux_blitz_buggy(blitz::Range(0,_end),blitz::Range(1,Nux_blitz_buggy(1)-1-1))=uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(blitz::Range(0,_end),blitz::Range(2,_end))-2*uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+uxold(blitz::Range(0,_end),blitz::Range(0,Nuxold(1)-2-1)));\n'
使用:
ipdb> expr_code
'ux_blitz_not_buggy(_all,blitz::Range(1,Nux_blitz_not_buggy(1)-1-1))=uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(_all,blitz::Range(2,_end))-2*uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+uxold(_all,blitz::Range(0,Nuxold(1)-2-1)));\n'
因此,blitz::Range(0,_end)
变为_all
并且他们的行为方式不同。
为方便起见,这里有一个完整的脚本可以重现问题,只有在问题出现时才会成功。
import numpy as np
from scipy.weave import blitz
def test_blitz_bug(N=4):
ReI = 1.2
ux_blitz_buggy, ux_blitz_not_buggy, ux_np = np.zeros((N, N)), np.zeros((N, N)), np.zeros((N, N))
uxold = np.random.randn(N, N)
ux_np[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])
expr_buggy = 'ux_blitz_buggy[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])'
expr_not_buggy = 'ux_blitz_not_buggy[:,1:-1] = uxold[:,1:-1] + ReI* (uxold[:,2:] - 2*uxold[:,1:-1] + uxold[:,0:-2])'
blitz(expr_buggy)
blitz(expr_not_buggy)
assert not np.allclose(ux_blitz_buggy, ux_np)
assert np.allclose(ux_blitz_not_buggy, ux_np)
if __name__ == '__main__':
test_blitz_bug()