使用blit = True时的matlibplot FuncAnimation问题

时间:2019-12-06 23:01:48

标签: python-3.x

我正在尝试使用FuncAnimation对三个地块进行动画处理,但出现错误

_Traceback (most recent call last): File "/Users/XX/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/__init__.py", line 216, in process func(*args, **kwargs) File "/Users/XX/anaconda3/lib/python3.7/site-packages/matplotlib/animation.py", line 953, in _start self._init_draw() File "/Users/XX/anaconda3/lib/python3.7/site-packages/matplotlib/animation.py", line 1741, in _init_draw a.set_animated(self._blit) AttributeError: 'list' object has no attribute 'set_animated'

如果我改为使用blit=True,则该代码有效。看来问题与return line,有关,根据其他帖子,应该返回一个列表而不是一个元组,但是我不确定返回line的正确方法是什么。

谢谢

我的代码是

import math
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
dt = 1e-5
dx = 5e-3
s = dt / dx**2
x_min = -1
x_max = 2
x = np.arange(x_min, x_max, dx)
print("s=",s)

def norm(f):
    return np.sum(f*f.conjugate()).real*dx

def norm_sep(fre, fim):
    return np.sum(fre*fre+fim*fim).real*dx

def initial_psi(x):
    A = 1
    x0 = 0.2
    w = 0.1
    k0 = 100
    temp = np.exp(-(x-x0)**2/w**2)*np.exp(1j*k0*x)
    return np.exp(-(x-x0)**2/w**2)*np.exp(1j*k0*x) / np.sqrt(norm(temp))

def potential(x):
    V_temp = np.zeros_like(x)
    mask = (x > (x_max+x_min)/2) & (x < 3*(x_max+x_min)/4)
    V0 = 5000
    V_temp[mask] = V0
    return V_temp

print(x.shape)
V = potential(x)
psi0 = initial_psi(x)
print(norm(psi0))
plt.figure()
plt.plot(x, V)
plt.plot(x, psi0.real)
plt.plot(x, psi0.imag)
plt.plot(x, (psi0*psi0.conjugate()).real)
#plt.show()

psi_re_old = np.zeros_like(x)
psi_re_new = np.zeros_like(x)
psi_im_old = np.zeros_like(x)
psi_im_new = np.zeros_like(x)

def density(m, psi_re_old_, psi_re_new_, psi_im_old_, psi_im_new_):
    if(m%2 == 0):
        density = psi_re_new_*psi_re_new_+psi_im_old_*psi_im_new_
    else:
        density = psi_re_new_*psi_re_old_+psi_im_new_*psi_im_new_
    return density

m=0
psi_re_old = psi0.real.copy()
psi_re_new = psi0.real.copy()
psi_im_old = psi0.imag.copy()
psi_im_new = psi0.imag.copy()

m=1
psi_im_new[1:-1] = psi_im_old[1:-1] + s*psi_re_new[0:-2] + s*psi_re_new[2:] - 2*(s+V[1:-1]*dt)*psi_re_new[1:-1]

def increment_step(m, psi_re_old, psi_re_new, psi_im_old, psi_im_new):
    m = m+1

    if(m%2 == 0):
        psi_re_old = psi_re_new.copy()
        psi_re_new[1:-1] = psi_re_old[1:-1] - s*psi_im_new[0:-2] - s*psi_im_new[2:] + 2*(s+V[1:-1]*dt)*psi_im_new[1:-1]

    else:
        psi_im_old = psi_im_new.copy()
        psi_im_new[1:-1] = psi_im_old[1:-1] + s*psi_re_new[0:-2] + s*psi_re_new[2:] - 2*(s+V[1:-1]*dt)*psi_re_new[1:-1]
    return (m, psi_re_old, psi_re_new, psi_im_old, psi_im_new)


#steps = 100000
#for i in range(0, steps):
#    (m, psi_re_old, psi_re_new, psi_im_old, psi_im_new) = increment_step(m, psi_re_old, psi_re_new, psi_im_old, psi_im_new)

plt.figure()

plt.plot(x, psi_re_new)
plt.plot(x, psi_im_new)
plt.plot(x, density(m, psi_re_old, psi_re_new, psi_im_old, psi_im_new))
max_val_wf = np.max(psi_re_new)
max_val_V = np.max(V)
plt.plot(x, V*max_val_wf/max_val_V)
#plt.show()

# create a figure with two subplots
fig, (ax1, ax2, ax3) = plt.subplots(3,1)

# intialize two line objects (one in each axes)
line1, = ax1.plot([], [], lw=2)
line2, = ax2.plot([], [], lw=2, color='g')
line3, = ax3.plot([], [], lw=2, color='r')
ax1.plot(x, V*max_val_wf/max_val_V)
ax2.plot(x, V*max_val_wf/max_val_V)
ax3.plot(x, V*max_val_wf/max_val_V)
plt.tight_layout()
line = [line1, line2, line3]

# the same axes initalizations as before (just now we do it for both of them)
ax1.set_ylim(-2,2)
ax1.set_xlim(x_min, x_max)
ax1.grid()
ax1.set_ylabel(r'Re$\Psi(x,t)$')
ax1.set_title("Evolution of Wavefunction and Probability Density in real space")

ax2.set_ylim(-2, 2)
ax2.set_xlim(x_min, x_max)
ax2.grid()
ax2.set_ylabel(r'Im$\Psi(x,t)$')

ax3.set_ylim(0, 10)
ax3.set_xlim(x_min, x_max)
ax3.grid()
ax3.set_xlabel(r'$x$')
ax3.set_ylabel(r'$|\Psi(x,t)|^2$')

time_text = ax1.text(0.75, 0.95,'',horizontalalignment='left',verticalalignment='top', transform=ax1.transAxes)
step_text = ax1.text(0.75, 0.85,'',horizontalalignment='left',verticalalignment='top', transform=ax1.transAxes)
norm_text = ax1.text(0.75, 0.75,'',horizontalalignment='left',verticalalignment='top', transform=ax1.transAxes)
time_per_frame = dt     # You could change this to make the animation faster or slower

def run_init():
    line[0].set_data([], [])
    line[1].set_data([], [])
    line[2].set_data([], [])
    return line,

def run(i, psi_re_old, psi_re_new, psi_im_old, psi_im_new):
    m = i+1
    (m, psi_re_old, psi_re_new, psi_im_old, psi_im_new) = increment_step(m, psi_re_old, psi_re_new, psi_im_old, psi_im_new)
    y1data = psi_re_new
    y2data = psi_im_new
    y3data = density(m-1, psi_re_old, psi_re_new, psi_im_old, psi_im_new)

    # update the data of the three line objects
    line[0].set_data(x, y1data)
    line[1].set_data(x, y2data)
    line[2].set_data(x, y3data)
    #Display the current animation time
    time_text.set_text('tau = %.8f' % (i*time_per_frame))
    step_text.set_text('step = ' + str(i))
    norm_text.set_text('norm = %.2f' % (np.sum(y3data)*dx))
    print(dt)
    return line, time_text, step_text, norm_text
#
ani = FuncAnimation(fig, run, fargs = (psi_re_old, psi_re_new, psi_im_old, psi_im_new), init_func=run_init,
                               frames=1000, interval=1, blit=True)

plt.show()```

0 个答案:

没有答案