嵌套,循环和条件累加的列表理解

时间:2020-08-05 21:37:31

标签: python numpy list-comprehension

我正在尝试将这段代码转换为列表理解:

a = np.random.rand(10) #input vector
n = len(a) # element count of input vector
b = np.random.rand(3) #coefficient vector
nb = len(b) #element count of coefficients
d = nb #decimation factor (could be any integer < len(a))
 
c = []
for i in range(0, n, d):
    psum = 0
    for j in range(nb):
        if i + j < n:
            psum += a[i + j]*b[j]
    c.append(psum)

我尝试了以下建议:

例如:

from itertools import accumulate
c = [accumulate([a[i + j] * b[j] for j in range(nb) if i + j < n] ) for i in range(0, n, d)]

稍后,当尝试从c(例如c[:index])获取值时:

TypeError: 'NoneType' object is not subscriptable

或者:

from functools import partial
def get_val(a, b, i, j, n):
    if i + j < n:
        return(a[i + j] * b[j])
    else:
        return(0)
c = [
         list(map(partial(get_val, i=i, j=j, n=n), a, b)) 
             for i in range(0, n, d) 
             for j in range(nb)
    ]

get_val中,返回(a [i + j] * b [j])

IndexError: invalid index to scalar variable.

或者:

psum_pieces = [[a[i + j] * b[j] if i + j < n else 0 for j in range(nb)] for i in range(0, n, d)]
c = [sum(psum) for psum in psum_pieces]

以及这些方法的许多其他迭代。任何指导将不胜感激。

2 个答案:

答案 0 :(得分:1)

如果我正确理解了您想要的东西

res = [sum(a[i+j]*b[j] for j in range(nb) if i+j < n) for i in range(0,n,d)]

对于每个i,这将在a[i+j]*b[j]中乘积j的乘积之和0nb-1i+j < n {1}}

答案 1 :(得分:1)

您真的不需要在这里使用列表理解功能。使用numpy,您可以创建一个快速的流水线解决方案,该解决方案不直接在解释器中运行任何循环。

首先将a转换为形状为(n // d, nb)的2D数组。缺少的元素(例如,循环中的i + j >= n可以为零,因为这会使psum的相应增量为零:

# pre-compute i+j as a 2D array
indices = np.arange(nb) + np.arange(0, n, d)[:, None]
# we only want valid locations
mask = indices < n

t = np.zeros(indices.shape)
t[mask] = a[indices[mask]]

现在您可以直接将c计算为

(t * b).sum(axis=1)

我怀疑,如果您将此解决方案与未使用numba编译的用vanilla python编写的任何代码进行基准测试,它将更快。