我有一个3D ndarry对象,它包含光谱数据(即空间xy维度和能量维度)。我想从线图中提取并绘制每个像素的光谱。目前,我正在沿着我感兴趣的轴使用np.ndenumerate这样做,但它很慢。我希望尝试np.apply_along_axis
,看看它是否更快,但我不断收到一个奇怪的错误。
什么有效:
# Setup environment, and generate sample data (much smaller than real thing!)
import numpy as np
import matplotlib.pyplot as plt
ax = range(0,10) # the scale to use when plotting the axis of interest
ar = np.random.rand(4,4,10) # the 3D data volume
# Plot all lines along axis 2 (i.e. the spectrum contained in each pixel)
# on a single line plot:
for (x,y) in np.ndenumerate(ar[:,:,1]):
plt.plot(ax,ar[x[0],x[1],:],alpha=0.5,color='black')
据我所知,这基本上是一个循环,效率低于基于数组的方法,因此我想尝试使用np.apply_along_axis
的方法,看看它是否更快。这是我对python的第一次尝试,然而,我仍然在发现它是如何工作的,所以如果这个想法存在根本缺陷,请把我弄好!
我想尝试一下:
# define a function to pass to apply_along_axis
def pa(y,x):
if ~all(np.isnan(y)): # only do the plot if there is actually data there...
plt.plot(x,y,alpha=0.15,color='black')
return
# check that the function actually works...
pa(ar[1,1,:],ax) # should produce a plot - does for me :)
# try to apply to to the whole array, along the axis of interest:
np.apply_along_axis(pa,2,ar,ax) # does not work... booo!
产生的错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-109-5192831ba03c> in <module>()
12 # pa(ar[1,1,:],ax)
13
---> 14 np.apply_along_axis(pa,2,ar,ax)
//anaconda/lib/python2.7/site-packages/numpy/lib/shape_base.pyc in apply_along_axis(func1d, axis, arr, *args)
101 holdshape = outshape
102 outshape = list(arr.shape)
--> 103 outshape[axis] = len(res)
104 outarr = zeros(outshape, asarray(res).dtype)
105 outarr[tuple(i.tolist())] = res
TypeError: object of type 'NoneType' has no len()
任何想法在这里出错/关于如何做得更好的建议都会很棒。
谢谢!
答案 0 :(得分:2)
apply_along_axis
从您的函数输出中创建一个新数组。
您正在返回None
(不返回任何内容)。因此错误。 Numpy检查返回输出的长度,看看它对新数组是否有意义。
因为您没有从结果中构建新数组,所以没有理由使用apply_along_axis
。它不会更快。
但是,您当前的ndenumerate
语句完全等同于:
import numpy as np
import matplotlib.pyplot as plt
ar = np.random.rand(4,4,10) # the 3D data volume
plt.plot(ar.reshape(-1, 10).T, alpha=0.5, color='black')
一般情况下,您可能希望执行以下操作:
for pixel in ar.reshape(-1, ar.shape[-1]):
plt.plot(x_values, pixel, ...)
通过这种方式,您可以轻松地迭代高光谱阵列中每个像素的光谱。
这里你的瓶颈可能不是你如何使用阵列。 <{1}}中使用相同参数分别绘制每一行的效果会有些低效。
构建需要稍长的时间,但matplotlib
渲染速度要快得多。 (基本上,使用LineCollection
告诉matplotlib不要检查每行的属性,并将它们全部传递给低级渲染器以便以相同的方式绘制。你绕过了一堆个人{ {1}}呼吁支持大型对象的LineCollection
个。{/ p>
缺点是,代码的可读性会低一些。
我稍后会添加一个例子。