有什么办法可以加快这段代码的速度吗?它通过碰撞计算pi

时间:2019-05-03 15:53:01

标签: python performance matplotlib runtime collision

此代码通过碰撞计算pi;它要求用户输入N来确定第二块的质量。它可以正常工作,当N> = 2时,它将永远需要运行。我希望能够至少具有N = 5并具有合理的运行时间。 我认为问题在于速度和x1,x2计算。有太多需要附加的东西,它要花很长时间才能运行,然后才需要很长时间才能附加。

我的代码表明,我尝试使用Numba来加快运行时间,但这似乎无济于事。我当前正在使用RK-4方法更新位置,并且以前尝试过Verlet方法,该方法似乎对运行时没有影响。 在此方面的任何帮助将不胜感激。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
from numba import jit
import time

start = time.time()

@jit(nopython=True)
def RK4_1(x, vx):
    return vx

@jit(nopython=True)
def RK4_2(x, vx):
    return 0

@jit(nopython=True)
def iterate(x1, x2, vx1, vx2, col):
    k1 = dt*RK4_1(x1, vx1)
    k2 = dt*RK4_1(x1 + k1/2, vx1)
    k3 = dt*RK4_1(x1 + k2/2, vx1)
    k4 = dt*RK4_1(x1 + k3, vx1)
    x1 += (1/6)*(k1 + 2*k2 + 2*k3 + k4)

    k1 = dt*RK4_1(x2, vx2)
    k2 = dt*RK4_1(x2 + k1/2, vx2)
    k3 = dt*RK4_1(x2 + k2/2, vx2)
    k4 = dt*RK4_1(x2 + k3, vx2)
    x2 += (1/6)*(k1 + 2*k2 + 2*k3 + k4)

    k1 = dt*RK4_2(x1, vx1)
    k2 = dt*RK4_2(x1, vx1 + k1/2)
    k3 = dt*RK4_2(x1, vx1 + k2/2)
    k4 = dt*RK4_2(x1, vx1 + k3)
    vx1 += (1/6)*(k1 + 2*k2 + 2*k3 + k4)

    k1 = dt*RK4_2(x2, vx2)
    k2 = dt*RK4_2(x2, vx2 + k1/2)
    k3 = dt*RK4_2(x2, vx2 + k2/2)
    k4 = dt*RK4_2(x2, vx2 + k3)
    vx2 += (1/6)*(k1 + 2*k2 + 2*k3 + k4)

    if x1 < 0:
        x1 = 0
        vx1 = -vx1
        col += 1

    if x2 < x1:
        x2 = x1
        vx1_i = vx1
        vx2_i = vx2
        vx1 = (2*m2*vx2_i + m1*vx1_i - m2*vx1_i)/(m1+m2)
        vx2 = (2*m1*vx1_i + m2*vx2_i - m1*vx2_i)/(m1+m2)
        col += 1

    return x1, x2, vx1, vx2, col

dt = 0.01

m1 = 1
N = int(input("Enter an integer N that will determine the mass of the second block: "))
m2 = 100**N

w1 = 1
w2 = w1*(100**N)**(1/3)
x1 = 1
x2 = 1.15
y1 = 1
y2 = 1
vx1 = 0
vx2 = -1
col = 0

x1arr = np.array([])
x2arr = np.array([])
y1arr = np.array([])
y2arr = np.array([])
vx1arr = np.array([])
vx2arr = np.array([])
colarr = np.array([])

t = 0

while (vx2 < 0) or (abs(vx1) > abs(vx2)):
    x1, x2, vx1, vx2, col = iterate(x1, x2, vx1, vx2, col)
    #print(vx1, vx2)
    t += dt

    x1arr = np.append(x1arr, x1)
    x2arr = np.append(x2arr, x2)
    y1arr = np.append(y1arr, y1)
    y2arr = np.append(y2arr, y2)
    vx1arr = np.append(vx1arr, vx1)
    vx2arr = np.append(vx2arr, vx2)
    colarr = np.append(colarr, col)

print("Number of collisions: %f" % (col))

speed = 1000

def update_plot(i, fig, scat1, scat2, txt):
    last = 0
    if(i>int(len(x1arr)/speed)-2):
        last=1  
    s = int(speed*i)
    scat1.set_data(x1arr[s],y1arr[s])
    scat2.set_data(x2arr[s],y2arr[s])
    #scat.set_sizes(lx1=5, lx2=5*N)
    txt.set_text('x1= %.3f   m1=%.0f\nx2= %.3f   m1=%.0f\nCollisions=%.0f\n t=%.3fs' % (x1arr[s],m1,x2arr[s],m2,colarr[s]+last,(s*dt))) #update of legend
    print("Frame %d Rendered" % (s))
    return scat1, scat2, txt,

size = 5*N
fig =  plt.figure() 
ax = fig.add_subplot(111)
ax.set_xlim([0, 3]) #animation scale
ax.set_ylim([0,5])
ax.grid()
txt = ax.text(0.05, 0.8, '', transform=ax.transAxes) 
scat1, = ax.plot([], [],'s', c='r', markersize=5) 
scat2, = ax.plot([], [],'s', c='r', markersize=5*(N+1))
anim = FuncAnimation(fig, update_plot, fargs = (fig, scat1, scat2, txt), frames = int(len(x1arr)/speed), interval = 1, blit=True, repeat=False)
anim.save("originalpi.mp4", fps=30, bitrate=-1)

end = time.time()
print("Total runtime in seconds:  ", end-start)

#plt.show()

如果我设置N = 0或N = 1,则运行时间是合理的。但是只要N = 2,运行时间就会增加到大约700秒。

0 个答案:

没有答案