将任意数量的数组的所有可能组合相加并应用限制

时间:2016-11-03 20:34:59

标签: python arrays numpy combinations itertools

我正在尝试生成任意数量的数组的所有组合的数组。从生成的数组中,我想添加一个约束,数字的总和必须位于两个边界之间(比如'lower'和'upper')

这样做的一种方法是使用cartersian,对元素求和,并选择属于下限和上限的元素。但是,主要限制是在给定大量输入数组的情况下可能会耗尽内存。另一种方法是使用itertools.product:

import itertools
import numpy as np

def arraysums(arrays,lower,upper):
    p = itertools.product(*arrays)
    r = list()

    for n in p:
        s = sum(n)
        if lower <= s <= upper:
            r.append(n)

    return r

N = 8
a = np.arange(N)
b = np.arange(N)-N/2

arraysums((a,b),lower=5,upper=6)

返回如下结果:

[(2, 3),
 (3, 2),
 (3, 3),
 (4, 1),
 (4, 2),
 (5, 0),
 (5, 1),
 (6, -1),
 (6, 0),
 (7, -2),
 (7, -1)]

此方法具有内存效率,但如果数组很大,则可能会非常慢,例如此示例在10分钟内运行:

a = np.arange(32.)
arraysums(6*(a,),lower=10,upper=20)

我正在寻找一种更快的方法。

2 个答案:

答案 0 :(得分:3)

你可以使用递归。例如,如果从第一个数组中选择了item,则其余数组的新下限和上限应为lower-itemupper-item

这里的主要优点是你可以在每个阶段短路枚举元组。考虑所有值都是正值的情况。然后我们可以 自动抛出大于的其他数组中的任何值 upper-item。这会智能地减少每个搜索空间的大小 递归的程度。

import itertools

def arraysums_recursive_all_positive(arrays, lower, upper):
    # Assumes all values in arrays are positive
    if len(arrays) <= 1:
        result = [(item,) for item in arrays[0] if lower <= item <= upper]
    else:
        result = []
        for item in arrays[0]:
            subarrays = [[item2 for item2 in arr if item2 <= upper-item] 
                      for arr in arrays[1:]]
            if min(len(arr) for arr in subarrays) == 0:
                continue
            result.extend(
                [(item,)+tup for tup in arraysums_recursive_all_positive(
                    subarrays, lower-item, upper-item)])
    return result

def arraysums(arrays,lower,upper):
    p = itertools.product(*arrays)
    r = list()

    for n in p:
        s = sum(n)
        if lower <= s <= upper:
            r.append(n)

    return r

a = list(range(32))

对于此测试用例,arraysums_recursive_all_positivearraysums快了688倍:

In [227]: %time arraysums_recursive_all_positive(6*(a,),lower=10,upper=20)
CPU times: user 360 ms, sys: 8.01 ms, total: 368 ms
Wall time: 367 ms

In [73]: %time arraysums(6*(a,),lower=10,upper=20)
CPU times: user 4min 8s, sys: 0 ns, total: 4min 8s
Wall time: 4min 8s

在一般情况下,当arrays中的值可能为负值时,我们可以为arrays中的每个值添加适当的金额,以保证新arrays中的所有值是积极的。我们还可以调整lowerupper限制,以说明价值的这种变化。因此,我们可以将一般问题减少到具有所有正值的arrays的特殊情况:

def arraysums_recursive(arrays, lower, upper):
    minval = min(item for arr in arrays for item in arr)
    # Subtract minval from arrays to guarantee all the values are positive
    arrays = [[item-minval for item in arr] for arr in arrays]
    # Adjust the lower and upper bounds accordingly
    lower -= minval*len(arrays)
    upper -= minval*len(arrays)
    result = arraysums_recursive_all_positive(arrays, lower, upper)
    # Readjust the result by adding back minval
    result = [tuple([item+minval for item in tup]) for tup in result]
    return result

请注意,arraysums_recursive正确处理负值,而 arraysums_recursive_all_positive没有:

In [312]: arraysums_recursive([[10,30],[20,40],[-35,-40]],lower=10,upper=20)
Out[312]: [(10, 40, -35), (10, 40, -40), (30, 20, -35), (30, 20, -40)]

In [311]: arraysums_recursive_all_positive([[10,30],[20,40],[-35,-40]],lower=10,upper=20)
Out[311]: []

虽然arraysums_recursivearraysums_recursive_all_positive慢,但

In [37]: %time arraysums_recursive(6*(a,),lower=10,upper=20)
CPU times: user 1.03 s, sys: 0 ns, total: 1.03 s
Wall time: 852 ms

它仍然比arraysums快290倍。

答案 1 :(得分:1)

这是一种利用NumPy broadcasting -

的矢量化方法
def arraysums_vectorized(arrays,lower,upper):
    a,b = arrays
    sums = a[:,None] + b
    r,c = np.nonzero((lower <= sums) & (sums <= upper))
    return np.column_stack((a[r], b[c]))

运行时测试 -

In [2]: # Inputs
   ...: N = 800
   ...: a = np.arange(N)
   ...: b = np.arange(N)-N/2
   ...: 

In [3]: l = 500
   ...: u = 600
   ...: out1 = arraysums((a,b),lower=l,upper=u)
   ...: out2 = arraysums_vectorized((a,b),lower=l,upper=u)
   ...: print np.allclose(out1,out2)
   ...: 
True

In [4]: %timeit arraysums((a,b),lower=l,upper=u)
1 loops, best of 3: 508 ms per loop

In [5]: %timeit arraysums_vectorized((a,b),lower=l,upper=u)
100 loops, best of 3: 7.11 ms per loop