FancyArrowPatch在散点图上无法正确显示

时间:2020-04-26 11:57:30

标签: python matplotlib plot

我有一个散点图,试图在上面画一个箭头。

当我尝试使用ax.add_patch(arrow1)添加补丁时,什么也没发生。

如果我使用plt.gca()。add_patch(arrow1),则会在错误的位置出现箭头。

这是我使用gca()获得的图:

Plot showing arrow in wrong place

import numpy as np
import matplotlib.patches as pat
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import matplotlib.animation as animation
import matplotlib.colors as mcolors

# initialise arrays
n = 2 # number of particles
its = 10 # number of iternations
m = np.zeros(n)
x = np.zeros((n,its+1))
v = np.zeros((n,its+1))
dt = 1000

G = 6.67408*10**(-11)

# set initial conditions
m[0], m[1] = 10**20, 10**12
x[0,0], x[1,0] = 0, 10**6
x[0,1], x[1,1] = 6.67408*10**(-5), 993325.92
v[0,1], v[1,1] = 6.67408*10**(-8), -6.67408

def leap_x(xnow, vnow):
   xleap = xnow + dt*vnow
   return xleap

def leap_v(i, xnow2, xnow1, vnow):
    a = G*m[i]*(xnow2 - xnow1)/np.abs((xnow2 - xnow1)**3)
    vleap = vnow + dt*a
    return vleap

for j in range(2, its+1):
    v[0,j] = leap_v(1, x[1,j-1], x[0,j-1], v[0,j-1])
    v[1,j] = leap_v(0, x[0,j-1], x[1,j-1], v[1,j-1])
    x[0,j] = leap_x(x[0,j-1], v[0,j])
    x[1,j] = leap_x(x[1,j-1], v[1,j])

# plotting

def sci_not(x):
    return "{:.1e}".format(x)

def make_color(num):
    alphas = np.linspace(0.8, 0.1, len(num))
    rgba_colors = np.zeros((2,len(num),4))
    rgba_colors[0,:,0] = m1Color[0] # m1 r
    rgba_colors[0,:,1] = m1Color[1] # m2 g
    rgba_colors[0,:,2] = m1Color[2] # m3 b
    rgba_colors[1,:,0] = m2Color[0] # m2 r
    rgba_colors[1,:,1] = m2Color[1] # m2 g
    rgba_colors[1,:,2] = m2Color[2] # m2 b
    rgba_colors[:,:,3] = alphas
    return rgba_colors

def reset_axes():
    plt.clf()
    ax = plt.axes(xlim=(-x[1,0]*.3,x[1,0]*1.1))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_xticks([x[0,0], x[1,0]/2, x[1,0]])
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.0E'))
    ax.set_title(Title)
    ax.set_xlabel('Distance / m')
    plt.yticks([])

def res_legend(mass1,mass2):
    lgnd = plt.legend(handles=[mass1, mass2], frameon=False)
    lgnd.legendHandles[0]._sizes = [30]
    lgnd.legendHandles[1]._sizes = [30]

def animate(i):
    reset_axes()
    numPlots = [k for k in range(i, i-4, -1) if k >= 0]
    rgba_colors = make_color(numPlots)
    newx = x[:,numPlots]
    newy = np.zeros(newx.shape)
    mass1 = plt.scatter(newx[0], newy[0], marker='o', s=scale[0], c=rgba_colors[0], label=labels[0])
    mass2 = plt.scatter(newx[1], newy[1], marker='o', s=scale[1], c=rgba_colors[1], label=labels[1])
    res_legend(mass1,mass2)
    return (mass1,mass2)

fig = plt.figure(num=1, figsize=(10,5))
Title = '2 - body sim using Leap Frog'
ax = plt.axes(xlim=(-x[1,0]*.3,x[1,0]*1.1))
scale = np.array([max(m)/min(m)*0.00005,50])
radii = np.sqrt(scale)/2
m1Color = mcolors.to_rgba('deepskyblue')
m2Color = mcolors.to_rgba('darkorchid')
labels = [str(sci_not(m[0]))+' Kg', str(sci_not(m[1]))+' Kg']
rgba_color = make_color([1])
reset_axes()
mass1 = plt.scatter(x[0,0], 0, marker='o', s=scale[0], c=rgba_color[0], label=labels[0])
mass2 = plt.scatter(x[1,0], 0, marker='o', s=scale[1], c=rgba_color[1], label=labels[1])
arrow1 = pat.FancyArrowPatch(posA=(x[0,0],0),posB=(5**5,0),arrowstyle='-|>', mutation_scale=20, 
                                    shrinkA=radii[0], shrinkB=0)
# plt.gca().add_patch(arrow1)
ax.add_patch(arrow1)
res_legend(mass1,mass2)

# anim = animation.FuncAnimation(fig, animate, frames=its+1, interval=1000, repeat=False)
plt.show()

我希望能够绘制从分散点之一的边缘开始的箭头。指向另一个。我将根据点的速度缩放其长度(每个点代表重力吸引到另一个物体的质量,该图像是其位置动画的第一帧)。

1 个答案:

答案 0 :(得分:0)

ax.add_patch()的调用不起作用,因为您正在清除图形(plt.clf())并在函数ax中创建了一个新的reset_axes()对象。但是该ax对象是函数的本地对象,不会覆盖全局范围中存在的ax对象。但是,原来的ax对象现在无效,因为调用plt.clf()

时原来的轴已被破坏。

通常,您应该避免不必要的清除图形(甚至轴),尤其是在动画环境中。

一种快速解决方案是像这样更改reset_axes()

def reset_axes(ax=None):
    ax = ax or plt.gca()
    ax.cla()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_xticks([x[0,0], x[1,0]/2, x[1,0]])
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.0E'))
    ax.set_title(Title)
    ax.set_xlabel('Distance / m')
    ax.set_yticks([])

这样,您就不会破坏原始轴,只需清除内容即可。