基于索引函数的数组的矢量化操作

时间:2014-01-24 23:23:33

标签: python arrays numpy vectorization

我有一个表示3D点之间的函数的数组。因此,作为索引,它获得6元组。现在我想对这个数组的元素应用一个函数,但是这个函数不仅取决于元素的值,还取决于它的索引。因此,如果A是矩阵,并且m和n是我们的3D点,A [m,n]存储其值,k是0到3之间的值,则f(A,k)[m,n]等于:

- m[k]**2如果m==n

- m[k]**2-n[k]**2否则

以下是我的代码:

import numpy as np
def func(a,k):
    b=np.empty(a.shape)
    for i in range(a.flatten().size):
        ind=np.unravel_index(i,a.shape)
        if ind[0:3]==ind[3:6]:
            b[ind]=a[ind]*ind[0:3][k]**2
        else:
            b[ind]=a[ind]*(ind[0:3][k]**2-ind[3:6][k]**2)
    return b
a=np.arange(729).reshape((3,3,3,3,3,3))
print func(a,2)

无论如何都要对这段代码进行归档吗?

P.S。这是我实际需要做的简化版本。

1 个答案:

答案 0 :(得分:2)

使用numpy.indices()创建索引数组,然后就可以进行计算:

import numpy as np
def func(a,k):
    b=np.empty(a.shape)
    for i in range(a.flatten().size):
        ind=np.unravel_index(i,a.shape)
        if ind[0:3]==ind[3:6]:
            b[ind]=a[ind]*ind[0:3][k]**2
        else:
            b[ind]=a[ind]*(ind[0:3][k]**2-ind[3:6][k]**2)
    return b

def func2(a,k):
    b = np.empty(a.shape)
    ind = np.indices(a.shape).reshape(6, -1)
    mask = np.all(ind[:3] == ind[3:6], axis=0)
    ar = a.ravel()
    br = b.ravel()
    br[mask] = ar[mask]*ind[k, mask]**2
    mask = ~mask
    br[mask] = ar[mask]*(ind[k, mask]**2 - ind[3+k, mask]**2)
    return b

a = np.arange(729).reshape((3,3,3,3,3,3))
b1 = func(a, 2)
b2 = func2(a, 2)
np.allclose(b1, b2)

这是%timeit结果:

%timeit func(a, 2)
%timeit func2(a, 2)

输出:

100 loops, best of 3: 16.4 ms per loop
1000 loops, best of 3: 579 µs per loop

您可以针对您的情况稍微优化一下:

def func3(a,k):
    b = np.empty(a.shape)
    ind = np.indices(a.shape).reshape(6, -1)
    mask = ~np.all(ind[:3] == ind[3:6], axis=0)
    ar = a.ravel()
    br = b.ravel()
    br[:] = ar*ind[k]**2
    br[mask] -= ar[mask]*ind[3+k, mask]**2
    return b