完全矢量化的元素的有序的有序减法

时间:2019-02-12 13:28:18

标签: python python-3.x numpy numpy-ndarray

想象一个mxn数组a和一个1xn数组b,我们想从b中减去a,这样{{从b的第一个元素中减去1}},然后从a中减去最大值0和b-a[0],依此类推...

所以:

a[1]

所以我们想得到:x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) a = np.repeat(x, 100000).reshape(10, 100000) b = np.repeat(np.array([5]), 100000).reshape(1, 100000) ,重复100,000次。

我已经管理了下面的函数,该函数可以提供所需的结果:

[ 0,  0,  1,  4,  5,  6,  7,  8,  9, 10]

但是,它还没有完全向量化。所以:

def func(a, b):
    n = np.copy(a)
    m = np.copy(b)
    for i in range(len(n)):
        n[i] = np.where(n[i] >= m, n[i] - m, 0)
        m = np.maximum(0, m - a[i])
        if not m.any():
            return n
    return n

理想情况下,我想摆脱for循环,并使其尽可能地矢量化。

谢谢。

1 个答案:

答案 0 :(得分:0)

我认为您可以将功能向量化:

import numpy as np

def func_vec(a, b):
    ar = np.roll(a, 1, axis=0)
    ar[0] = 0
    ac = np.cumsum(ar, axis=0)
    bc = np.maximum(b - ac, 0)
    return np.maximum(a - bc, 0)

快速测试:

import numpy as np

def func(a, b):
    n = np.copy(a)
    m = np.copy(b)
    for i in range(len(n)):
        n[i] = np.where(n[i] >= m, n[i] - m, 0)
        m = np.maximum(0, m - a[i])
        if not m.any():
            return n
    return n

np.random.seed(100)
n = 100000
m = 10
num_max = 100
a = np.random.randint(num_max, size=(m, n))
b = np.random.randint(num_max, size=(1, n))
print(np.all(func(a, b) == func_vec(a, b)))
# True

但是,与矢量化算法相比,您的算法具有一个重要优势,那就是当发现没有其他可减去的内容时,它将停止迭代。这意味着,根据问题的大小和特定值(确定提前停止发生的时间,如果确定的话),矢量化的解决方案实际上可能会变慢。请参阅上面的示例:

%timeit func(a, b)
# 5.09 ms ± 78.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit func_vec(a, b)
# 12.4 ms ± 939 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

但是,您可以使用Numba获得“两全其美”的解决方案:

import numpy as np
import numba as nb

@nb.njit
def func_nb(a, b):
    n = np.copy(a)
    m = np.copy(b)
    zero = np.array(0, dtype=a.dtype)
    for i in range(len(n)):
        n[i] = np.maximum(n[i] - m, zero)
        m = np.maximum(zero, m - a[i])
        if not m.any():
            return n
    return n

print(np.all(func(a, b) == func_nb(a, b)))
# True
%timeit func_nb(a, b)
# 3.36 ms ± 461 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)