numpy向量化函数可重复连续元素的块

时间:2018-07-03 12:45:23

标签: python algorithm numpy vectorization

numpy具有аrepeat函数,该函数将数组的每个元素重复给定的(每)元素次数。

我想实现一个功能类似的功能,但不重复单个元素,而是重复大小可变的连续元素块。本质上,我需要以下功能:

import numpy as np

def repeat_blocks(a, sizes, repeats):
    b = []    
    start = 0
    for i, size in enumerate(sizes):
        end = start + size
        b.extend([a[start:end]] * repeats[i])
        start = end
    return np.concatenate(b)

例如,给定

a = np.arange(20)
sizes = np.array([3, 5, 2, 6, 4])
repeats = np.array([2, 3, 2, 1, 3])

然后

repeat_blocks(a, sizes, repeats)

返回

array([ 0,  1,  2, 
        0,  1,  2,

        3,  4,  5,  6,  7, 
        3,  4,  5,  6,  7, 
        3,  4,  5,  6,  7, 

        8,  9, 
        8,  9,

        10, 11, 12, 13, 14, 15,

        16, 17, 18, 19,
        16, 17, 18, 19,
        16, 17, 18, 19 ])

我想以性能的名义将这些循环推入numpy中。这可能吗?如果可以,怎么办?

2 个答案:

答案 0 :(得分:2)

这是一种使用cumsum的矢量化方法-

# Get repeats for each group using group lengths/sizes
r1 = np.repeat(np.arange(len(sizes)), repeats)

# Get total size of output array, as needed to initialize output indexing array
N = (sizes*repeats).sum() # or np.dot(sizes, repeats)

# Initialize indexing array with ones as we need to setup incremental indexing
# within each group when cumulatively summed at the final stage. 
# Two steps here:
# 1. Within each group, we have multiple sequences, so setup the offsetting
# at each sequence lengths by the seq. lengths preceeeding those.
id_ar = np.ones(N, dtype=int)
id_ar[0] = 0
insert_index = sizes[r1[:-1]].cumsum()
insert_val = (1-sizes)[r1[:-1]]

# 2. For each group, make sure the indexing starts from the next group's
# first element. So, simply assign 1s there.
insert_val[r1[1:] != r1[:-1]] = 1

# Assign index-offseting values
id_ar[insert_index] = insert_val

# Finally index into input array for the group repeated o/p
out = a[id_ar.cumsum()]

答案 1 :(得分:2)

此功能非常适合使用Numba进行加速:

@numba.njit
def repeat_blocks_jit(a, sizes, repeats):
    out = np.empty((sizes * repeats).sum(), a.dtype)
    start = 0
    oi = 0
    for i, size in enumerate(sizes):
        end = start + size
        for rep in range(repeats[i]):
            oe = oi + size
            out[oi:oe] = a[start:end]
            oi = oe
        start = end
    return out

这比Divakar的纯NumPy解决方案要快得多,并且更接近于原始代码。我根本没有努力优化它。请注意,np.dot()np.repeat()在这里不能使用,但是当所有代码都被编译时,这并不重要。

另外,由于它是njit的意思是“ nopython”模式,如果您要进行许多这样的调用,甚至可以使用@numba.njit(nogil=True)并获得多核加速。