Python中具有不同维数的数组的矢量化计算

时间:2019-01-23 00:37:08

标签: python arrays numpy vectorization

在科学计算中,可以将3D字段离散化为F[nx, ny, nz],其中nxnynz是3个方向上的网格点数。在每一点上,假设我们都附加了n-by-n张量。因此,对于张量字段,我们可以使用5D数组表示T[n, n, nx, ny, nz]。可以将任意点[i, j, k]的张量选择为T[:, :, i, j, k]。如果我想为每个点计算非对角线元素的总和,我想使用代码

import numpy as np
r = np.zeros((nx, ny, nz)) 
for i in range(nx):
    for j in range(ny):
        for k in range(nz):
            r[i,j,k] = np.sum(T[:,:,i,j,k])-np.trace(T[:,:,i,j,k])

结果数组r和张量字段T具有不同的维度。在Python中,每个元素的循环计算效率都很低。还有其他方法可以对具有不同维数的数组进行矢量化或高效计算。还是可以使用其他数据类型/结构。

1 个答案:

答案 0 :(得分:4)

以下是两种不同的选择。第一个使用ndarray.sum和NumPy integer array indexing。第二种选择使用np.einsum

def using_sum(T):
    total = T.sum(axis=1).sum(axis=0)
    m = np.arange(T.shape[0])
    trace = T[m, m].sum(axis=0)
    return total - trace

def using_einsum(T):
    return np.einsum('mnijk->ijk', T) - np.einsum('nnijk->ijk', T)

np.einsum的第一个参数指定求和的下标。

'mnijk->ijk'表示T有下标mnijk,求和后仅剩下ijk下标。因此,对mn下标执行求和。这使得 np.einsum('mnijk->ijk', T)[i,j,k]等于np.sum(T[:,:,i,j,k]),但以一次矢量化计算来计算整个数组。

类似地,'nnijk->ijk'告诉np.einsum T有下标nnijk,并且只有ijk下标求和。因此,总和超过n。由于重复n,因此n上的总和将计算出轨迹。

我喜欢np.einsum,因为它传达了计算的意图 简洁地。但自从开始以来,熟悉using_sum的工作方式也很好 它使用基本的NumPy操作。这是嵌套循环的一个很好的例子 通过使用对整个数组进行操作的NumPy方法可以避免这种情况。


这里是perfplot,根据orig比较using_sumusing_einsumn的性能,其中T被用作形状为(10, 10, n, n, n)

import perfplot
import numpy as np

def orig(T):
    _, _, nx, ny, nz = T.shape
    r = np.zeros((nx, ny, nz)) 
    for i in range(nx):
        for j in range(ny):
            for k in range(nz):
                r[i,j,k] = np.sum(T[:,:,i,j,k])-np.trace(T[:,:,i,j,k])
    return r

def using_einsum(T):
    r = np.einsum('mnijk->ijk', T) - np.einsum('nnijk->ijk', T)
    return r

def using_sum(T):
    total = T.sum(axis=1).sum(axis=0)
    m = np.arange(T.shape[0])
    trace = T[m,m].sum(axis=0)
    return total - trace

def make_T(n):
    return np.random.random((10,10,n,n,n))

perfplot.show(
    setup=make_T,
    kernels=[orig, using_sum, using_einsum],
    n_range=range(2, 80, 3),
    xlabel='n')

enter image description here

perfplot.show还检查origusing_sumusing_einsum返回的值是否相等。