将函数应用于ndarray的0维

时间:2016-07-12 17:08:21

标签: python function numpy multidimensional-array vectorization

问题

  • 我有ndarray,由arr定义,是一个n维立方体,每个维度的长度为m

  • 我希望通过沿维度func切片并将每个n=0 - 昏暗切片作为函数的输入来执行函数n-1

这似乎适用于map(),但我找不到合适的numpy变体。 np.vectorise似乎将n-1 - 张量分割为单个标量条目。 apply_along_axisapply_over_axes似乎也不合适。

我的问题是我需要将任意函数作为输入传递,因此我看不到einsum可行的解决方案。

问题

  • 您是否知道使用numpy的最佳np.asarray(map(func, arr))替代方案?

实施例

我通过以下方式将示例数组arr定义为4 - dim cube(或4-tensor):

m, n = 3, 4 
arr = np.arange(m**n).reshape((m,)*n)

我定义了一个示例函数f

def f(x):
    """makes it obvious how the np.ndarray is being passed into the function"""
    try: # perform an op using x[0,0,0] which is expected to exist
        i = x[0,0,0]
    except:
        print '\nno element x[0,0,0] in x: \n{}'.format(x)
        return np.nan
    return x-x+i

此函数的预期结果res将保持相同的形状,但会满足以下条件:

print all([(res[i] == i*m**(n-1)).all() for i in range(m)])

这适用于默认的map()函数

res = np.asarray(map(f, a))
print all([(res[i] == i*m**(n-1)).all() for i in range(m)])
True

我希望np.vectorize的工作方式与map()相同,但它会在标量条目中起作用:

res = np.vectorize(f)(a)

no element x[0,0,0] in x: 
0
...

1 个答案:

答案 0 :(得分:2)

鉴于arr为4d,而您的fn适用于3d数组,

np.asarray(map(func, arr))

看起来非常合理。我会使用列表理解表单,但这是编程风格的问题

np.asarray([func(i) for i in arr])

for i in arr遍历arr的第一维。实际上,它将arr视为3d数组的列表。然后它将结果列表重新组合成一个4d数组。

np.vectorize doc可以更明确地说明使用标量的函数。但是,是的,它将值传递为标量。请注意,np.vectorize没有提供传递迭代轴参数的规定。当你的函数从多个数组中获取值时,它是最有用的,比如

 [func(a,b) for a,b in zip(arrA, arrB)]

它概括了zip所以允许广播。但否则它是一个迭代的解决方案。它对func的内容一无所知,因此无法加快其通话速度。

np.vectorize最终会调用np.frompyfunc,这有点不那么通用会慢一点。但它也将标量传递给了func。

np.apply_along/over_ax(e/i)s也会迭代一个或多个轴。您可能会发现他们的代码具有指导性,但我同意他们不适用于此。

地图方法的一个变体是分配结果数组和索引:

In [45]: res=np.zeros_like(arr,int)
In [46]: for i in range(arr.shape[0]):
    ...:     res[i,...] = f(arr[i,...])

如果您需要在与第1个轴不同的轴上进行迭代,这可能会更容易。

你需要做自己的时间,看看哪个更快。

========================

使用就地修改对第一维进行迭代的示例:

In [58]: arr.__array_interface__['data']  # data buffer address
Out[58]: (152720784, False)

In [59]: for i,a in enumerate(arr):
    ...:     print(a.__array_interface__['data'])
    ...:     a[0,0,:]=i
    ...:     
(152720784, False)   # address of the views (same buffer)
(152720892, False)
(152721000, False)

In [60]: arr
Out[60]: 
array([[[[ 0,  0,  0],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        ...

       [[[ 1,  1,  1],
         [30, 31, 32],
         ...

       [[[ 2,  2,  2],
         [57, 58, 59],
         [60, 61, 62]],
       ...]]])

当我遍历一个数组时,我得到一个从公共数据缓冲区上的连续点开始的视图。如果我修改视图,如上所述或甚至修改a[:]=...,我修改原始视图。我不需要写任何东西。但是不要使用a = ....,它会破坏原始数组的链接。