Python:动画3D散点图变慢

时间:2016-11-30 14:16:06

标签: python animation matplotlib scatter3d

我的程序在每个时间步骤中绘制我文件中粒子的位置。不幸的是,虽然我使用了matplotlib.animation,但速度越来越慢。瓶颈在哪里?

我的两个粒子的数据文件如下所示:

#     x   y   z
# t1  1   2   4
#     4   1   3
# t2  4   0   4
#     3   2   9
# t3  ...

我的剧本:

import numpy as np                          
import matplotlib.pyplot as plt            
from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation

# Number of particles
numP = 2
# Dimensions
DIM = 3
timesteps = 2000

with open('//home//data.dat', 'r') as fp:
    particleData = []
    for line in fp:
        line = line.split()
        particleData.append(line)

x = [float(item[0]) for item in particleData]
y = [float(item[1]) for item in particleData]
z = [float(item[2]) for item in particleData]      

# Attaching 3D axis to the figure
fig = plt.figure()
ax = p3.Axes3D(fig)

# Setting the axes properties
border = 1
ax.set_xlim3d([-border, border])
ax.set_ylim3d([-border, border])
ax.set_zlim3d([-border, border])


def animate(i):
    global x, y, z, numP
    #ax.clear()
    ax.set_xlim3d([-border, border])
    ax.set_ylim3d([-border, border])
    ax.set_zlim3d([-border, border])
    idx0 = i*numP
    idx1 = numP*(i+1)
    ax.scatter(x[idx0:idx1],y[idx0:idx1],z[idx0:idx1])

ani = animation.FuncAnimation(fig, animate, frames=timesteps, interval=1, blit=False, repeat=False)
plt.show()

2 个答案:

答案 0 :(得分:2)

我建议在这种情况下使用pyqtgraph。来自文档的引文:

  

其主要目标是1)为其提供快速,交互式图形   显示数据(情节,视频等)和2)提供辅助工具   快速应用程序开发(例如,属性树,如   在Qt Designer中使用。)

您可以在安装后查看一些示例:

import pyqtgraph.examples
pyqtgraph.examples.run()

此小代码段生成1000个随机点,并通过不断更新不透明度将其显示在3D散点图中,类似于pyqtgraph.examples中的3D散点图示例:

from pyqtgraph.Qt import QtCore, QtGui
import pyqtgraph.opengl as gl
import numpy as np

app = QtGui.QApplication([])
w = gl.GLViewWidget()
w.show()
g = gl.GLGridItem()
w.addItem(g)

#generate random points from -10 to 10, z-axis positive
pos = np.random.randint(-10,10,size=(1000,3))
pos[:,2] = np.abs(pos[:,2])

sp2 = gl.GLScatterPlotItem(pos=pos)
w.addItem(sp2)

#generate a color opacity gradient
color = np.zeros((pos.shape[0],4), dtype=np.float32)
color[:,0] = 1
color[:,1] = 0
color[:,2] = 0.5
color[0:100,3] = np.arange(0,100)/100.

def update():
    ## update volume colors
    global color
    color = np.roll(color,1, axis=0)
    sp2.setData(color=color)

t = QtCore.QTimer()
t.timeout.connect(update)
t.start(50)


## Start Qt event loop unless running in interactive mode.
if __name__ == '__main__':
    import sys
    if (sys.flags.interactive != 1) or not hasattr(QtCore, PYQT_VERSION'):
        QtGui.QApplication.instance().exec_()

小gif让你了解表现:

enter image description here

修改

在每个时间步骤显示多个点有点棘手,因为gl.GLScatterPlotItem仅将(N,3)-arrays作为点位置,请参阅here。您可以尝试制作ScatterPlotItems字典,其中每个字典都包含特定点的所有时间步长。然后需要相应地调整更新功能。您可以在下面找到一个示例,其中pos(100,10,3)-array,表示每个点的100个时间步长。我将更新时间缩短为1000 ms以获得较慢的动画。

from pyqtgraph.Qt import QtCore, QtGui
import pyqtgraph.opengl as gl
import numpy as np

app = QtGui.QApplication([])
w = gl.GLViewWidget()
w.show()
g = gl.GLGridItem()
w.addItem(g)

pos = np.random.randint(-10,10,size=(100,10,3))
pos[:,:,2] = np.abs(pos[:,:,2])

ScatterPlotItems = {}
for point in np.arange(10):
    ScatterPlotItems[point] = gl.GLScatterPlotItem(pos=pos[:,point,:])
    w.addItem(ScatterPlotItems[point])

color = np.zeros((pos.shape[0],10,4), dtype=np.float32)
color[:,:,0] = 1
color[:,:,1] = 0
color[:,:,2] = 0.5
color[0:5,:,3] = np.tile(np.arange(1,6)/5., (10,1)).T

def update():
    ## update volume colors
    global color
    for point in np.arange(10):
        ScatterPlotItems[point].setData(color=color[:,point,:])
    color = np.roll(color,1, axis=0)

t = QtCore.QTimer()
t.timeout.connect(update)
t.start(1000)


## Start Qt event loop unless running in interactive mode.
if __name__ == '__main__':
    import sys
    if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'):
    QtGui.QApplication.instance().exec_()

请注意,在此示例中,散点图中会显示所有点,但是,在每个时间步中都会更新颜色不透明度(颜色数组中的第4个维度)以获取动画。您也可以尝试更新点而不是颜色以获得更好的性能......

答案 1 :(得分:1)

我猜你的瓶颈是在动画的每一帧中调用ax.scatterax.set_xlim3d并且类似。

理想情况下,您应该调用scatter一次,然后使用散点图返回的对象及set_...函数(more details here)中的animate属性。< / p>

我无法弄清楚如何使用scatter,但如果您使用ax.plot(x, y, z, 'o'),则可以按照演示方法here进行操作。

x, y, z使用一些随机数据。它会像这样工作

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation
from numpy.random import random

# Number of particles
numP = 2
# Dimensions
DIM = 3
timesteps = 2000

x, y, z = random(timesteps), random(timesteps), random(timesteps)

# Attaching 3D axis to the figure
fig = plt.figure()
ax = p3.Axes3D(fig)

# Setting the axes properties
border = 1
ax.set_xlim3d([-border, border])
ax.set_ylim3d([-border, border])
ax.set_zlim3d([-border, border])
line = ax.plot(x[:1], y[:1], z[:1], 'o')[0]


def animate(i):
    global x, y, z, numP
    idx1 = numP*(i+1)
    # join x and y into single 2 x N array
    xy_data = np.c_[x[:idx1], y[:idx1]].T
    line.set_data(xy_data)
    line.set_3d_properties(z[:idx1])

ani = animation.FuncAnimation(fig, animate, frames=timesteps, interval=1, blit=False, repeat=False)
plt.show()