使用apply_along_axis绘图

时间:2014-03-16 01:42:20

标签: python numpy matplotlib scipy

我有一个3D ndarry对象,它包含光谱数据(即空间xy维度和能量维度)。我想从线图中提取并绘制每个像素的光谱。目前,我正在沿着我感兴趣的轴使用np.ndenumerate这样做,但它很慢。我希望尝试np.apply_along_axis,看看它是否更快,但我不断收到一个奇怪的错误。

什么有效:

# Setup environment, and generate sample data (much smaller than real thing!)
import numpy as np
import matplotlib.pyplot as plt

ax = range(0,10) # the scale to use when plotting the axis of interest
ar = np.random.rand(4,4,10) # the 3D data volume

# Plot all lines along axis 2 (i.e. the spectrum contained in each pixel) 
# on a single line plot:

for (x,y) in np.ndenumerate(ar[:,:,1]):
    plt.plot(ax,ar[x[0],x[1],:],alpha=0.5,color='black')

据我所知,这基本上是一个循环,效率低于基于数组的方法,因此我想尝试使用np.apply_along_axis的方法,看看它是否更快。这是我对python的第一次尝试,然而,我仍然在发现它是如何工作的,所以如果这个想法存在根本缺陷,请把我弄好!

我想尝试一下:

# define a function to pass to apply_along_axis
def pa(y,x):
    if ~all(np.isnan(y)): # only do the plot if there is actually data there...
        plt.plot(x,y,alpha=0.15,color='black')
    return

# check that the function actually works...
pa(ar[1,1,:],ax) # should produce a plot - does for me :)

# try to apply to to the whole array, along the axis of interest:
np.apply_along_axis(pa,2,ar,ax) # does not work... booo!

产生的错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-109-5192831ba03c> in <module>()
     12 # pa(ar[1,1,:],ax)
     13 
---> 14 np.apply_along_axis(pa,2,ar,ax)

//anaconda/lib/python2.7/site-packages/numpy/lib/shape_base.pyc in apply_along_axis(func1d, axis, arr, *args)
    101         holdshape = outshape
    102         outshape = list(arr.shape)
--> 103         outshape[axis] = len(res)
    104         outarr = zeros(outshape, asarray(res).dtype)
    105         outarr[tuple(i.tolist())] = res

TypeError: object of type 'NoneType' has no len()

任何想法在这里出错/关于如何做得更好的建议都会很棒。

谢谢!

1 个答案:

答案 0 :(得分:2)

apply_along_axis从您的函数输出中创建一个新数组

您正在返回None(不返回任何内容)。因此错误。 Numpy检查返回输出的长度,看看它对新数组是否有意义。

因为您没有从结果中构建新数组,所以没有理由使用apply_along_axis。它不会更快。

但是,您当前的ndenumerate语句完全等同于:

import numpy as np
import matplotlib.pyplot as plt

ar = np.random.rand(4,4,10) # the 3D data volume
plt.plot(ar.reshape(-1, 10).T, alpha=0.5, color='black')

一般情况下,您可能希望执行以下操作:

for pixel in ar.reshape(-1, ar.shape[-1]):
    plt.plot(x_values, pixel, ...)

通过这种方式,您可以轻松地迭代高光谱阵列中每个像素的光谱。


这里你的瓶颈可能不是你如何使用阵列。 <{1}}中使用相同参数分别绘制每一行的效果会有些低效。

构建需要稍长的时间,但matplotlib渲染速度要快得多。 (基本上,使用LineCollection告诉matplotlib不要检查每行的属性,并将它们全部传递给低级渲染器以便以相同的方式绘制。你绕过了一堆个人{ {1}}呼吁支持大型对象的LineCollection个。{/ p>

缺点是,代码的可读性会低一些。

我稍后会添加一个例子。