在Python中绘制3d数组的最有效方法是什么?
例如:
volume = np.random.rand(512, 512, 512)
其中数组项表示每个像素的灰度颜色。
以下代码运行速度太慢:
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.gca(projection='3d')
volume = np.random.rand(20, 20, 20)
for x in range(len(volume[:, 0, 0])):
for y in range(len(volume[0, :, 0])):
for z in range(len(volume[0, 0, :])):
ax.scatter(x, y, z, c = tuple([volume[x, y, z], volume[x, y, z], volume[x, y, z], 1]))
plt.show()
答案 0 :(得分:4)
为了获得更好的性能,请尽可能避免多次调用ax.scatter
。
相反,将所有x
,y
,z
坐标和颜色打包到1D数组中(或
列表),然后拨打ax.scatter
一次:
ax.scatter(x, y, z, c=volume.ravel())
问题(就CPU时间和内存而言)增长为size**3
,其中size
是多维数据集的边长。
此外,ax.scatter
会尝试渲染所有size**3
点而不考虑
事实上,大多数这些点被外围的人遮挡了
外壳
这有助于减少volume
中的点数 - 也许是
以某种方式对其进行汇总或重新采样/插值 - 在渲染之前。
我们还可以将O(size**3)
到O(size**2)
所需的CPU和内存减少
只绘制外壳:
import functools
import itertools as IT
import numpy as np
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def cartesian_product_broadcasted(*arrays):
"""
http://stackoverflow.com/a/11146645/190597 (senderle)
"""
broadcastable = np.ix_(*arrays)
broadcasted = np.broadcast_arrays(*broadcastable)
dtype = np.result_type(*arrays)
rows, cols = functools.reduce(np.multiply, broadcasted[0].shape), len(broadcasted)
out = np.empty(rows * cols, dtype=dtype)
start, end = 0, rows
for a in broadcasted:
out[start:end] = a.reshape(-1)
start, end = end, end + rows
return out.reshape(cols, rows).T
# @profile # used with `python -m memory_profiler script.py` to measure memory usage
def main():
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
size = 512
volume = np.random.rand(size, size, size)
x, y, z = cartesian_product_broadcasted(*[np.arange(size, dtype='int16')]*3).T
mask = ((x == 0) | (x == size-1)
| (y == 0) | (y == size-1)
| (z == 0) | (z == size-1))
x = x[mask]
y = y[mask]
z = z[mask]
volume = volume.ravel()[mask]
ax.scatter(x, y, z, c=volume, cmap=plt.get_cmap('Greys'))
plt.show()
if __name__ == '__main__':
main()
size=512
我们仍然需要大约1.3 GiB的内存。还要注意即使你有足够的总内存,但由于缺少RAM,程序使用交换空间,那么程序的整体速度将
急剧减速。如果您发现自己处于这种情况,那么唯一的解决方案就是找到一种更智能的方法来使用更少的点渲染可接受的图像,或购买更多的RAM。
答案 1 :(得分:2)
首先,512x512x512点的密集网格绘制的数据太多,不是从技术角度来看,而是在观察绘图时能够看到任何有用的数据。您可能需要提取一些等值面,查看切片等。如果大多数点是不可见的,那么它可能没问题,但是您应该要求ax.scatter
仅显示非零点以使其更快。
那就是说,这就是你如何能够更快地做到这一点。这些技巧是为了消除所有Python循环,包括隐藏在itertools
等库中的循环。
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
# Make this bigger to generate a dense grid.
N = 8
# Create some random data.
volume = np.random.rand(N, N, N)
# Create the x, y, and z coordinate arrays. We use
# numpy's broadcasting to do all the hard work for us.
# We could shorten this even more by using np.meshgrid.
x = np.arange(volume.shape[0])[:, None, None]
y = np.arange(volume.shape[1])[None, :, None]
z = np.arange(volume.shape[2])[None, None, :]
x, y, z = np.broadcast_arrays(x, y, z)
# Turn the volumetric data into an RGB array that's
# just grayscale. There might be better ways to make
# ax.scatter happy.
c = np.tile(volume.ravel()[:, None], [1, 3])
# Do the plotting in a single call.
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(),
y.ravel(),
z.ravel(),
c=c)