在循环中更新pyplot三维散点图,网格线重叠点

时间:2014-02-06 22:34:32

标签: python loops matplotlib scatter

我正在使用循环的每次迭代更新三维散点图。重绘图时,网格线“穿过”或“覆盖”点,这使得我的数据更难以可视化。如果我构建一个单独的3d图(没有循环更新),这不会发生。下面的代码演示了最简单的情况:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2])
plt.draw()

for i in range(0, 20):
    time.sleep(3)   #make changes more apparent/easy to see

    Y = np.random.rand(100, 3)*5
    ax.cla()    
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()

还有其他人遇到过这个问题吗?

4 个答案:

答案 0 :(得分:3)

问题出在ax.cla()plt.cla()调用中,MaxNoe似乎是正确的。事实上它似乎像known issue

然后出现问题,因为清晰轴方法在3D绘图中不起作用,对于3D散射,没有干净的方法来改变数据点的坐标(la sc.set_data(new_values)),如上所述in this mail list(我最近没有发现任何事情)。

然而,在邮件列表中,Ben Roon指出了一个可能对您有用的解决方法。

解决方法:

您需要在_ofsets3d函数返回的Line3DCollection对象的内部scatter变量中设置数据点的新坐标。

你改编的例子如下:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2])
fig.show()

for i in range(0, 20):
    plt.pause(1)

    Y = np.random.rand(100, 3)*5

    sc._offsets3d = (Y[:,0], Y[:,1], Y[:,2])
    plt.draw()

答案 1 :(得分:2)

我可以将其缩小到使用cla():

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x, y = np.meshgrid(np.linspace(-2,2), np.linspace(-2,2))

ax.plot_surface(x,y, x**2+y**2)
fig.savefig("fig_a.png")

ax.cla()
ax.plot_surface(x,y, x**2+y**2)

fig.savefig("fig_b.png")

这些是结果图: fig_a fig_b

答案 2 :(得分:1)

这只是一种解决方法,因为它无法解决MaxNoe指出的ax.cla()问题。它也不是特别漂亮,因为它清除整个数字,但它完成了所需的任务:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig1 = plt.figure()
ax1 = fig1.add_subplot(111, projection='3d')

x, y = np.meshgrid(np.linspace(-2,2), np.linspace(-2,2))

ax1.plot_surface(x,y, x**2+y**2)
fig1.savefig("fig_a.png")

fig1.clf()
ax1 = fig1.add_subplot(111, projection='3d')
ax1.plot_surface(x,y, x**2+y**2)

fig1.savefig("fig_b.png")

答案 3 :(得分:-2)

我建议使用ax = fig.gca(projection='3d')代替ax = fig.add_subplot(111, projection='3d')