Speeding up code with numpy

Time: 2014-10-21 13:58:23

Tags: python fft physics

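I'm new to Python, and I have this code that calculates the potential inside a 1x1 box using a Fourier series, but one part of it is far too slow (marked in the code below).

I would appreciate any help with this. I suspect I could do something with the numpy library, but I'm not that familiar with it.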

import numpy as np
from scipy.integrate import quad
import matplotlib.pyplot as plt
import pylab
import sys
from matplotlib import rc
rc('text', usetex=False)
rc('font', family = 'serif')

#One of the boundary conditions for the potential.

def func1(x,n):
    V_c = 1
    V_0 = V_c * np.sin(n*np.pi*x)
    return V_0*np.sin(n*np.pi*x)

#To calculate the potential inside a box:
def v(x,y):
    n = 1;
    sum = 0;
    nmax = 20;

    while n < nmax:
        [C_n, err] = quad(func1, 0, 1, args=(n), );
        sum = sum + 2*(C_n/np.sinh(np.pi*n)*np.sin(n*np.pi*x)*np.sinh(n*np.pi*y));
        n = n + 1;


    return sum;

def main(argv):
    x_axis = np.linspace(0,1,100)
    y_axis = np.linspace(0,1,100)
    V_0 = np.zeros(100)
    V_1 = np.zeros(100)

    n = 4;

    #Plotter for V0 = v_c * sin () x

    for i in range(100):
        V_0[i] = V_0_1(i/100, n)

    plt.plot(x_axis, V_0)
    plt.xlabel('x/L')
    plt.ylabel('V_0')
    plt.title('V_0(x) = sin(m*pi*x/L), n = 4')
    plt.show()

    #Plot for V_0 = V_c(1-(x-1/2)^4)

    for i in range(100):
        V_1[i] = V_0_2(i/100)


    plt.figure()

    plt.plot(x_axis, V_1)
    plt.xlabel('x/L')
    plt.ylabel('V_0')
    plt.title('V_0(x) = 1- (x/L - 1/2)^4)')
    #plt.legend()
    plt.show()

    #Plot V(x/L,y/L) on the boundary:

    V_0_Y = np.zeros(100)
    V_1_Y = np.zeros(100)
    V_X_0 = np.zeros(100)
    V_X_1 = np.zeros(100)

    for i in range(100):
        V_0_Y[i] = v(0, i/100)
        V_1_Y[i] = v(1, i/100)
        V_X_0[i] = v(i/100, 0)
        V_X_1[i] = v(i/100, 1)

    # V(x/L = 0, y/L):

    plt.figure()
    plt.plot(x_axis, V_0_Y)
    plt.title('V(x/L = 0, y/L)')
    plt.show()

    # V(x/L = 1, y/L):

    plt.figure()
    plt.plot(x_axis, V_1_Y)
    plt.title('V(x/L = 1, y/L)')
    plt.show()

    # V(x/L, y/L = 0):

    plt.figure()
    plt.plot(x_axis, V_X_0)
    plt.title('V(x/L, y/L = 0)')
    plt.show()

    # V(x/L, y/L = 1):

    plt.figure()
    plt.plot(x_axis, V_X_1)
    plt.title('V(x/L, y/L = 1)')
    plt.show()

    #Plot V(x,y)
####### 
# This is where the code is way too slow, it takes like 10 minutes when n in v(x,y) is 20.
#######

    V = np.zeros(10000).reshape((100,100))
    for i in range(100):
        for j in range(100):
            V[i,j] = v(j/100, i/100)

    plt.figure()
    plt.contour(x_axis, y_axis, V,  50)
    plt.savefig('V_1')
    plt.show()

if __name__ == "__main__":
    main(sys.argv[1:])

2 answers:

Answer 0 (score: 0)

You can find out how to use the FFT / DFT in this document:

Discretized continuous Fourier transform with numpy

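Also, regarding your V matrix, there are many ways to improve the execution speed. One is to make sure you use Python 3, or xrange() instead of range() if you are still on Python 2.*. I usually put these lines in my Python code so that it runs the same way on Python 3.* and 2.*: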

# Don't want to generate huge lists in memory... use standard range for Python 3.*
range = xrange if isinstance(range(2),
                             list) else range

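Then, instead of recomputing j/100 and i/100 on every iteration, you can precompute these values and put them in an array. Keep in mind that division is more expensive than multiplication! Something like: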

values = np.arange(100) / 100.0  # precomputed once; float divisor avoids integer division on Python 2

V = np.zeros(10000).reshape((100,100))
j = 0
while j < 100:
    i = 0
    while i < 100:
        V[i,j] = v(values[j], values[i])
        i += 1
    j += 1

Well, anyway, this is rather cosmetic and will not save your life; you still need to call the function v() for every element...
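One way to make each call cheaper while keeping the loops is to notice that the quad integrals inside v() depend only on n, not on x or y. The following is a minimal sketch, assuming the same func1 and nmax as in the question; the names NMAX, COEFFS and v_cached are hypothetical and introduced only for illustration:

```python
import numpy as np
from scipy.integrate import quad

def func1(x, n):
    # Same integrand as in the question: sin(n*pi*x)**2 with V_c = 1
    return np.sin(n * np.pi * x) ** 2

NMAX = 20  # assumed: same upper limit as nmax in the question

# Integrate once per order n = 1 .. NMAX-1 instead of once per grid point
COEFFS = {n: quad(func1, 0, 1, args=(n,))[0] for n in range(1, NMAX)}

def v_cached(x, y):
    # Same series as v(x, y) in the question, reusing the stored coefficients
    total = 0.0
    for n, c_n in COEFFS.items():
        total += 2 * (c_n / np.sinh(np.pi * n)
                      * np.sin(n * np.pi * x) * np.sinh(n * np.pi * y))
    return total
```

With this change the double loop still makes 10000 Python-level calls, but none of them performs a numerical integration.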

Then, you can use weave:

http://docs.scipy.org/doc/scipy-0.14.0/reference/tutorial/weave.html

Or write all the pure computation/loop code in C, compile it, and generate a module that can be called from Python.

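Answer 1 (score: 0)

You should look into numpy's broadcasting tricks and vectorization (there are several references; one of the first good links that pops up is from Matlab, but it is just as applicable to numpy. Can anyone recommend a good numpy link in the comments that I could point other users to in the future?).

From what I see in your code (once all the unnecessary stuff, like plots and unused functions, is removed), you are essentially doing this: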

from __future__ import division
from scipy.integrate import quad
import numpy as np
import matplotlib.pyplot as plt

def func1(x,n):
    return 1*np.sin(n*np.pi*x)**2

def v(x,y):
    n = 1;
    sum = 0;
    nmax = 20;

    while n < nmax:
        [C_n, err] = quad(func1, 0, 1, args=(n), );
        sum = sum + 2*(C_n/np.sinh(np.pi*n)*np.sin(n*np.pi*x)*np.sinh(n*np.pi*y));
        n = n + 1;
    return sum;

def main():
    x_axis = np.linspace(0,1,100)
    y_axis = np.linspace(0,1,100)
#######
# This is where the code is way too slow, it takes like 10 minutes when n in v(x,y) is 20.
#######

    V = np.zeros(10000).reshape((100,100))
    for i in range(100):
        for j in range(100):
            V[i,j] = v(j/100, i/100)

    plt.figure()
    plt.contour(x_axis, y_axis, V,  50)
    plt.show()

if __name__ == "__main__":
    main()

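If you look carefully (you could also use a profiler), you will see that you are integrating your function func1 (which I'll rename to integrand) about 20 times for each element of the 100x100 array V. However, the integrand does not change: its integral depends only on n, not on x or y! So you can already take it out of your loop. If you do that, and use broadcasting tricks, you could end up with something like this: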

import numpy as np
from scipy.integrate import quad
import matplotlib.pyplot as plt

def integrand(x,n):
    return 1*np.sin(n*np.pi*x)**2

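sine_order = np.arange(1,20).reshape(-1,1,1) # Orders n = 1..19 along the first axis, shape (19, 1, 1)
integration_results = np.empty_like(sine_order, dtype=float) # np.float was removed in NumPy 1.24
for enu, order in enumerate(sine_order.ravel()):
    # Pass a plain scalar order so that the integrand returns a scalar to quad
    integration_results[enu] = quad(integrand, 0, 1, args=(order,))[0]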

y,x = np.ogrid[0:1:.01, 0:1:.01]

term = integration_results / np.sinh(np.pi * sine_order) * np.sin(sine_order * np.pi * x) * np.sinh(sine_order * np.pi * y)
# This is the key: you have a 3D matrix here and with this summation, 
# you're basically squashing the entire 3D structure into a flat, 2D 
# representation. This 'squashing' is done by means of a sum.
V = 2*np.sum(term, axis=0)  

x_axis = np.linspace(0,1,100)
y_axis = np.linspace(0,1,100)
plt.figure()
plt.contour(x_axis, y_axis, V,  50)
plt.show()

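This runs in less than a second on my system. Broadcasting becomes much easier to understand if you take pen and paper and draw out the vectors you are "broadcasting" as if you were constructing a building from basic Tetris blocks.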
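To make the "building blocks" picture concrete, here is a small sketch that only inspects array shapes. It reuses the ogrid call from the code above; the names orders, plane and stack are hypothetical and chosen for illustration:

```python
import numpy as np

orders = np.arange(1, 20).reshape(-1, 1, 1)   # shape (19, 1, 1): one "floor" per order n
y, x = np.ogrid[0:1:.01, 0:1:.01]             # y has shape (100, 1), x has shape (1, 100)

# Axes of length 1 are stretched to match the other operand
plane = np.sin(np.pi * x) * np.sinh(np.pi * y)                    # column times row
stack = np.sin(orders * np.pi * x) * np.sinh(orders * np.pi * y)  # one plane per order

print(orders.shape, y.shape, x.shape)
print(plane.shape, stack.shape)
```

Summing stack over axis 0 collapses the floors of the building back into a single 100x100 plane, which is what np.sum(term, axis=0) does above.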

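The two versions are functionally the same, but one is fully vectorized while the other uses Python for-loops. As a new user of Python and numpy, I definitely recommend reading up on the basics of broadcasting. Good luck!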
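The equivalence can be checked numerically. The sketch below is an illustration rather than part of the original answer: it assumes a coarse 10x10 grid so that the loop version stays fast, and the names orders, coeffs, grid, V_loop and V_vec are hypothetical.

```python
import numpy as np
from scipy.integrate import quad

def integrand(x, n):
    return np.sin(n * np.pi * x) ** 2

orders = np.arange(1, 20)
# Each integral of sin(n*pi*x)**2 over [0, 1] equals 1/2 for integer n >= 1
coeffs = np.array([quad(integrand, 0, 1, args=(n,))[0] for n in orders])

grid = np.arange(10) / 10.0   # assumed coarse grid: 0.0, 0.1, ..., 0.9

# Loop version: V[i, j] = v(x=grid[j], y=grid[i]), as in the question
V_loop = np.zeros((10, 10))
for i in range(10):
    for j in range(10):
        for n, c_n in zip(orders, coeffs):
            V_loop[i, j] += 2 * (c_n / np.sinh(np.pi * n)
                                 * np.sin(n * np.pi * grid[j])
                                 * np.sinh(n * np.pi * grid[i]))

# Vectorized version: orders on axis 0, y on axis 1, x on axis 2
k = orders.reshape(-1, 1, 1)
c = coeffs.reshape(-1, 1, 1)
y = grid.reshape(-1, 1)
x = grid.reshape(1, -1)
V_vec = 2 * np.sum(c / np.sinh(np.pi * k)
                   * np.sin(k * np.pi * x) * np.sinh(k * np.pi * y), axis=0)

print(np.allclose(V_loop, V_vec))
```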