使用Polyfit预测行星的轨迹

时间:2019-11-29 19:25:47

标签: python numpy matplotlib

我正在模拟三体问题,并绘制3D轨迹。我试图弄清楚如何通过使用np.polyfit扩展绘图线来预测这些行星的轨迹。我有在数据框和2D绘图上执行此操作的经验,但在3D和不使用任何类型的数据框的情况下却没有。我提供了整个代码,扩展尝试在图形下方,包括错误消息。我正在寻找有关如何修改当前代码(尤其是扩展绘图的代码部分)以实现此目的的任何建议。 代码:

from scipy.integrate import odeint
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

# Universal Gravitational Const.

G = 6.674e-11

# Defining Mass

m1 = 0.9
m2 = 3.5
m3 = 1.6

# Init positions in graph (array)

pos1 = [-5,3,1]
pos2 = [5,12,10]
pos3 = [-7,1,27]

p01 = np.array(pos1)
p02 = np.array(pos2)
p03 = np.array(pos3)

# Init velocities (array)

vi1 = [10,-2,3]
vi2 = [-1,3,2]
vi3 = [3,-1,-6]

v01 = np.array(vi1)
v02 = np.array(vi2)
v03 = np.array(vi3)


#Function
def derivs_func(y,t,G,m1,m2,m3):
    d1 = np.array([y[0],y[1],y[2]]) #Unpacking the variables
    d2 = np.array([y[3],y[4],y[5]])   
    d3 = np.array([y[6],y[7],y[8]])
    v1 = np.array([y[9],y[10],y[11]])
    v2 = np.array([y[12],y[13],y[14]])
    v3 = np.array([y[15],y[16],y[17]])

    #Distance between objects
    dist12 = np.sqrt((pos2[0]-pos1[0])**2 + (pos2[1]-pos1[1])**2 + (pos2[2]-pos1[2])**2) 
    dist13 = np.sqrt((pos3[0]-pos1[0])**2 + (pos3[1]-pos1[1])**2 + (pos3[2]-pos1[2])**2)
    dist23 = np.sqrt((pos3[0]-pos2[0])**2 + (pos3[1]-pos2[1])**2 + (pos3[2]-pos2[2])**2)

    #Derivative equations: change in velocity and position
    dv1dt = m2 * (d2-d1)/dist12**3 + m3 * (d3-d1)/dist13**3 
    dv2dt = m1 * (d1-d2)/dist12**3 + m3 * (d3-d2)/dist23**3 
    dv3dt = m1 * (d1-d3)/dist13**3 + m2 * (d2-d3)/dist23**3 
    dd1dt = v1 
    dd2dt = v2
    dd3dt = v3

    derivs = np.array([dd1dt,dd2dt,dd3dt,dv1dt,dv2dt,dv3dt])  #Adding derivatives into an array
    derivs3 = derivs.flatten() #Turning the array into a 1D array

    return derivs3 #Returning the flattened array

yo = np.array([p01, p02, p03, v01, v02, v03]) #Initial conditions for position and velocity
y0 = yo.flatten()  #Turning the array into a 1D array

time = np.linspace(0,500,500) #Defining time

sol = odeint(derivs_func, y0, time, args = (G,m1,m2,m3)) #Calling the odeint function

x1 = sol[:,:3]
x2 = sol[:,3:6]
x3 = sol[:,6:9]


fig = plt.figure(figsize = (15,15)) #Creating a 3D plot
ax = plt.axes(projection = '3d')

ax.plot(x1[:,0],x1[:,1],x1[:,2],color = 'b') #Plotting the paths each planet takes
ax.plot(x2[:,0],x2[:,1],x2[:,2],color = 'r')
ax.plot(x3[:,0],x3[:,1],x3[:,2],color = 'g')

ax.scatter(x1[-1,0],x1[-1,1],x1[-1,2],color = 'b', marker = 'o', s=45, label = 'Mass 1') 
ax.scatter(x2[-1,0],x2[-1,1],x2[-1,2],color = 'r', marker = 'o',s=200, label = 'Mass 2')  
ax.scatter(x3[-1,0],x3[-1,1],x3[-1,2],color = 'g', marker = 'o',s=100, label = 'Mass 3')

ax.legend()

enter image description here

fig = plt.figure(figsize = (15,15))
ax = plt.axes(projection = '3d')

fit1 = np.poly1d(np.polyfit(x1[:,0],x1[:,1],7))
fit12 = np.poly1d(np.polyfit(x1[:,0],x1[:,2],7))
fit2 = np.poly1d(np.polyfit(x2[:,0],x2[:,1],7))
fit22 = np.poly1d(np.polyfit(x2[:,0],x2[:,2],7))
fit3 = np.poly1d(np.polyfit(x3[:,0],x3[:,1],7))
fit32 = np.poly1d(np.polyfit(x3[:,0],x3[:,2],7))

y1 = fit1(x1[:,0])
y12 = fit12(x1[:,0])
y2 = fit2(x2[:,0])
y22 = fit22(x2[:,0])
y3 = fit3(x3[:,0])
y32 = fit32(x3[:,0])

extended1 = np.linspace(x1[-1,0], x1[-1,0] + 300, 1)
extended2 = np.linspace(x2[-1,0], x2[-1,0] + 300, 1)
extended3 = np.linspace(x3[-1,0], x3[-1,0] + 300, 1)

yex1 = fit1(extended1)
yex12 = fit12(extended1)
yex2 = fit2(extended2)
yex22 = fit22(extended2)
yex3 = fit3(extended3)
yex32 = fit32(extended3)

ax.plot(x1[:,0],x1[:,1],x1[:,2])
ax.plot(x1[:,0],yex1,yex12)
ax.plot(x2[:,0],x2[:,1],x2[:,2])
ax.plot(x2[:,0],yex2,yex22)
ax.plot(x3[:,0],x3[:,1],x3[:,2])
ax.plot(x3[:,0],yex3,yex32)

错误消息:

Traceback (most recent call last)
<ipython-input-98-a55893800c7b> in <module>
     28 
     29 ax.plot(x1[:,0],x1[:,1],x1[:,2])
---> 30 ax.plot(x1[:,0],yex1,yex12)
     31 ax.plot(x2[:,0],x2[:,1],x2[:,2])
     32 ax.plot(x2[:,0],yex2,yex22)

~\Downloads\Anaconda\lib\site-packages\mpl_toolkits\mplot3d\axes3d.py in 
 plot(self, xs, ys, zdir, *args, **kwargs)
    1530         zs = np.broadcast_to(zs, len(xs))
    1531 
 -> 1532         lines = super().plot(xs, ys, *args, **kwargs)
    1533         for line in lines:
    1534             art3d.line_2d_to_3d(line, zs=zs, zdir=zdir)

~\Downloads\Anaconda\lib\site-packages\matplotlib\axes\_axes.py in 
plot(self, scalex, scaley, data, *args, **kwargs)
   1664         """
   1665         kwargs = cbook.normalize_kwargs(kwargs, 
mlines.Line2D._alias_map)
-> 1666         lines = [*self._get_lines(*args, data=data, **kwargs)]
   1667         for line in lines:
   1668             self.add_line(line)

~\Downloads\Anaconda\lib\site-packages\matplotlib\axes\_base.py in 
__call__(self, *args, **kwargs)
    223                 this += args[0],
    224                 args = args[1:]
--> 225             yield from self._plot_args(this, kwargs)
    226 
    227     def get_next_color(self):

~\Downloads\Anaconda\lib\site-packages\matplotlib\axes\_base.py in 
_plot_args(self, tup, kwargs)
    389             x, y = index_of(tup[-1])
    390 
--> 391         x, y = self._xy_from_xy(x, y)
    392 
    393         if self.command == 'plot':

~\Downloads\Anaconda\lib\site-packages\matplotlib\axes\_base.py in 
_xy_from_xy(self, x, y)
    268         if x.shape[0] != y.shape[0]:
    269             raise ValueError("x and y must have same first 
dimension, but "
--> 270                              "have shapes {} and {}".format(x.shape, 
y.shape))
    271         if x.ndim > 2 or y.ndim > 2:
    272             raise ValueError("x and y can be no greater than 2-D, 
but have "

ValueError: x and y must have same first dimension, but have shapes (500,) 
and (1,)

1 个答案:

答案 0 :(得分:1)

np.polyfit返回一个系数数组:

>>> np.polyfit(np.arange(4), np.arange(4), 1)
array([1.00000000e+00, 1.12255857e-16])

要将其转换为可调用的多项式,请对结果使用np.poly1d

>>> p = np.poly1d(np.polyfit(np.arange(4), np.arange(4), 1))
>>> p(1)
1.0000000000000002

因此,在您的项目中,更改以下行:

fit1 = np.polyfit(x1[:,0],x1[:,1],7)
# etc.

fit1 = np.poly1d(np.polyfit(x1[:,0],x1[:,1],7))
# etc.

编辑:您的新错误似乎源于以下事实:延伸轴各具有2个尺寸:

extended1 = np.linspace(x1[-1,:], x1[-1,:] + 300, 1) # extended1.ndim == 2 !
extended2 = np.linspace(x2[-1,:], x2[-1,:] + 300, 1)
extended3 = np.linspace(x3[-1,:], x3[-1,:] + 300, 1)

如果我正确理解了您的代码,那么您要执行的操作是

extended1 = np.arange(x1[-1, 0], x1[-1, 0] + 300)
extended2 = np.arange(x2[-1, 0], x2[-1, 0] + 300)
extended3 = np.arange(x3[-1, 0], x3[-1, 0] + 300)

以及下面的内容:

ax.plot(x1[:,0],x1[:,1],x1[:,2])
ax.plot(extended1,yex1,yex12)
ax.plot(x2[:,0],x2[:,1],x2[:,2])
ax.plot(extended2,yex2,yex22)
ax.plot(x3[:,0],x3[:,1],x3[:,2])
ax.plot(extended3,yex3,yex32)