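How to efficiently get all combinations where the sum is 10 or below in Python

Time: 2014-12-21 03:13:31

Tags: python itertools

Imagine you are trying to allocate some fixed resources (e.g. n=10) over some number of territories (e.g. t=5). I am trying to find out efficiently how to get all the combinations where the sum is n or below.

E.g. 10,0,0,0,0 is good, as is 0,0,5,5,0 and so on, while 3,3,3,3,3,3 is obviously wrong.

What I have so far: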

import itertools
t = 5
n = 10
r = [range(n+1)] * t
for x in itertools.product(*r): 
   if sum(x) <= n:          
       print(x)

This brute-force approach is incredibly slow, though; surely there is a better way?
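For a sense of scale, here is a minimal sketch (Python 3.8+, because it relies on math.comb) that compares how many tuples the brute-force loop visits with how many pass the filter. The closed form C(n+t, t) is the standard stars-and-bars count and is not stated in the question; the names `visited` and `kept` are only illustrative.

```python
import itertools
import math

t = 5   # number of territories
n = 10  # total resources available

# Every tuple that itertools.product generates, whether or not it is kept.
visited = (n + 1) ** t

# Tuples that survive the sum(x) <= n filter.
kept = sum(1 for x in itertools.product(range(n + 1), repeat=t) if sum(x) <= n)

print(visited, kept)

# Stars and bars: t non-negative integers with sum <= n.
assert kept == math.comb(n + t, t)
```

Fewer than 2% of the visited tuples survive the filter, which suggests that pruning early should pay off.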

Timings (1000 iterations):

Default (itertools.product)           --- time: 40.90 s
falsetru recursion                    --- time:  3.63 s
Aaron Williams Algorithm (impl, Tony) --- time:  0.37 s

4 answers:

Answer 0: (score: 3)

Write your own recursive function that does not recurse with an element unless a sum <= 10 is still possible.

def f(r, n, t, acc=[]):
    if t == 0:
        if n >= 0:
            yield acc
        return
    for x in r:
        if x > n:  # <---- do not recurse if sum is larger than `n`
            break
        for lst in f(r, n-x, t-1, acc + [x]):
            yield lst

t = 5
n = 10
for xs in f(range(n+1), n, t):
    print(xs)
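A Python 3 variant of the same pruning idea, as a rough sketch rather than the answer's exact code: instead of breaking out of the loop once x > n, it only iterates over values that still fit in the remaining budget. The name `bounded_tuples` is made up for this example.

```python
def bounded_tuples(n, t, prefix=()):
    """Yield every t-tuple of non-negative ints whose sum is at most n."""
    if t == 0:
        yield prefix
        return
    # Only values 0..n can still fit, so nothing is generated and discarded.
    for x in range(n + 1):
        yield from bounded_tuples(n - x, t - 1, prefix + (x,))


results = list(bounded_tuples(10, 5))
print(len(results))
```

Because each level shrinks the remaining budget, every recursive call leads to at least one valid tuple and no dead branches are explored.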

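Answer 1: (score: 2)

A possible approach follows. Definitely use it with caution (it is hardly tested at all, but the results for n = 10 and t = 5 look reasonable).

The approach involves no recursion. The algorithm that generates the partitions of a number n (10 in your example) having m elements (5 in your example) comes from Knuth's Volume 4. Each partition is then zero-extended if necessary, and all of its distinct permutations are generated with an algorithm from Aaron Williams, which I have seen referred to elsewhere. Both algorithms had to be translated to Python, which increases the chance that errors have crept in. The Williams algorithm wants a linked list, which I had to fake with a 2D array to avoid writing a linked-list class.

There goes an afternoon!

Code (note that your n is my maxn and your t is my p):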

import itertools

def visit(a, m):
    """ Utility function to add partition to the list"""
    x.append(a[1:m+1])

def parts(a, n, m):
    """ Knuth Algorithm H, Combinatorial Algorithms, Pre-Fascicle 3B
        Finds all partitions of n having exactly m elements.
        An upper bound on running time is (3 x number of
        partitions found) + m.  Not recursive!      
    """
    while (1):
        visit(a, m)
        while a[2] < a[1]-1:
            a[1] -= 1
            a[2] += 1
            visit(a, m)
        j=3
        s = a[1]+a[2]-1
        while a[j] >= a[1]-1:
            s += a[j]
            j += 1
        if j > m:
            break
        x = a[j] + 1
        a[j] = x
        j -= 1
        while j>1:
            a[j] = x
            s -= x
            j -= 1
            a[1] = s

def distinct_perms(partition):
    """ Aaron Williams Algorithm 1, "Loopless Generation of Multiset
        Permutations by Prefix Shifts".  Finds all distinct permutations
        of a list with repeated items.  I don't follow the paper all that
        well, but it _possibly_ has a running time which is proportional
        to the number of permutations (with 3 shift operations for each  
        permutation on average).  Not recursive!
    """

    perms = []
    val = 0
    nxt = 1
    l1 = [[partition[i],i+1] for i in range(len(partition))]
    l1[-1][nxt] = None
    #print(l1)
    head = 0
    i = len(l1)-2
    afteri = i+1
    tmp = []
    tmp += [l1[head][val]]
    c = head
    while l1[c][nxt] != None:
        tmp += [l1[l1[c][nxt]][val]]
        c = l1[c][nxt]
    perms.extend([tmp])
    while (l1[afteri][nxt] != None) or (l1[afteri][val] < l1[head][val]):
        if (l1[afteri][nxt] != None) and (l1[i][val]>=l1[l1[afteri][nxt]][val]):
            beforek = afteri
        else:
            beforek = i
        k = l1[beforek][nxt]
        l1[beforek][nxt] = l1[k][nxt]
        l1[k][nxt] = head
        if l1[k][val] < l1[head][val]:
            i = k
        afteri = l1[i][nxt]
        head = k
        tmp = []
        tmp += [l1[head][val]]
        c = head
        while l1[c][nxt] != None:
            tmp += [l1[l1[c][nxt]][val]]
            c = l1[c][nxt]
        perms.extend([tmp])

    return perms

maxn = 10 # max integer to find partitions of
p = 5  # max number of items in each partition

# Find all partitions of length p or less adding up
# to maxn or less

# Special cases (Knuth's algorithm requires n and m >= 2)
x = [[i] for i in range(maxn+1)]
# Main cases: runs parts fn (maxn^2+maxn)/2 times
for i in range(2, maxn+1):
    for j in range(2, min(p+1, i+1)):
        m = j
        n = i
        a = [0, n-m+1] + [1] * (m-1) + [-1] + [0] * (n-m-1)
        parts(a, n, m)
y = []
# For each partition, add zeros if necessary and then find
# distinct permutations.  Runs distinct_perms function once
# for each partition.
for part in x:
    if len(part) < p:
        y += distinct_perms(part + [0] * (p - len(part)))
    else:
        y += distinct_perms(part)
print(y)
print(len(y))
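The two-stage idea in this answer (partitions first, then distinct permutations of each zero-padded partition) can be cross-checked with a much shorter sketch. Note that this is deliberately not Knuth's Algorithm H or Williams' prefix-shift algorithm: it uses plain recursion for both stages, and the function names are invented for the example.

```python
from collections import Counter


def partitions_at_most(n, max_parts, largest=None):
    """Yield non-increasing lists of positive ints that sum to n,
    using at most max_parts parts."""
    if largest is None:
        largest = n
    if n == 0:
        yield []
        return
    if max_parts == 0:
        return
    for first in range(min(n, largest), 0, -1):
        for rest in partitions_at_most(n - first, max_parts - 1, first):
            yield [first] + rest


def distinct_permutations(items):
    """Yield each distinct ordering of items exactly once."""
    counts = Counter(items)

    def rec(prefix, remaining):
        if remaining == 0:
            yield tuple(prefix)
            return
        for value in sorted(counts):
            if counts[value] > 0:
                counts[value] -= 1
                prefix.append(value)
                yield from rec(prefix, remaining - 1)
                prefix.pop()
                counts[value] += 1

    return rec([], len(items))


def bounded_sum_tuples(max_total, length):
    """Yield all tuples of the given length whose sum is <= max_total."""
    for total in range(max_total + 1):
        for part in partitions_at_most(total, length):
            padded = part + [0] * (length - len(part))
            yield from distinct_permutations(padded)


all_tuples = list(bounded_sum_tuples(10, 5))
print(len(all_tuples))
```

Comparing `len(all_tuples)` (or the set of tuples) with the output of the answer's code is a cheap way to look for translation errors in the two non-recursive algorithms.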

Answer 2: (score: 2)

You can create all the candidate tuples with itertools.product and filter the results with numpy.

>>> import numpy as np
>>> from itertools import product

>>> t = 5
>>> n = 10
>>> r = range(n+1)

# Create the product numpy array
>>> prod = np.fromiter(product(r, repeat=t), np.dtype('u1,' * t))
>>> prod = prod.view('u1').reshape(-1, t)

# Extract only the rows that satisfy the condition (sum of n or below)
>>> prod[prod.sum(axis=1) <= n]

Timeit:

>>> %%timeit 
    prod = np.fromiter(product(r, repeat=t), np.dtype('u1,' * t))
    prod = prod.view('u1').reshape(-1, t)
    prod[prod.sum(axis=1) <= n]

10 loops, best of 3: 41.6 ms per loop

You can even speed up the product computation by populating combinations directly in numpy.
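The linked technique is not reproduced in this post. One way to build the full grid inside numpy, offered only as a sketch and assuming np.indices is an acceptable substitute for whatever the link describes, is:

```python
import numpy as np

t = 5
n = 10

# np.indices returns one coordinate grid per axis; flattening the grids
# and transposing gives one row per tuple in range(n + 1) ** t.
grid = np.indices((n + 1,) * t, dtype=np.uint8).reshape(t, -1).T

# Keep only the rows whose sum is n or below.
selected = grid[grid.sum(axis=1) <= n]
print(selected.shape)
```

This avoids the Python-level iteration in itertools.product, at the cost of still materializing all (n + 1) ** t rows before filtering.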

Answer 3: (score: 0)

You can optimize the algorithm using dynamic programming.

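Basically, keep an array a, where a[i][j] means "can I get a sum of j using the elements up to the i-th element (and using the i-th element itself)", assuming your elements are in an array t (not the number t you mentioned).

Then you can fill the array like this:

# Assumption: a starts as a len(t) x (n+1) table of False, and every t[i] <= n
a = [[False] * (n + 1) for _ in t]
a[0][t[0]] = True
for i in range(1, len(t)):
    a[i][t[i]] = True
    for j in range(t[i]+1, n+1):
        for k in range(0, i):
            if a[k][j-t[i]]:
                a[i][j] = True

Then, using this information, you can backtrack to recover the solution :)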
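A self-contained sketch of the table and the backtracking step might look like the following. It follows this answer's reading of the problem (pick a subset of a given list of elements that sums to a target), which is not the same as enumerating all allocations. It assumes the elements are positive integers, and the names `subset_sum_table` and `backtrack` are hypothetical.

```python
def subset_sum_table(elements, n):
    """a[i][j] is True when some subset of elements[:i + 1] that
    includes elements[i] sums to exactly j."""
    a = [[False] * (n + 1) for _ in elements]
    for i, value in enumerate(elements):
        if value > n:
            continue  # this element can never fit within the budget
        a[i][value] = True
        for j in range(value + 1, n + 1):
            a[i][j] = any(a[k][j - value] for k in range(i))
    return a


def backtrack(elements, a, target):
    """Return the indices of one subset summing to target, or None.
    Assumes positive elements and target > 0."""
    for i in range(len(elements) - 1, -1, -1):
        if a[i][target]:
            break
    else:
        return None
    chosen = []
    while True:
        chosen.append(i)
        target -= elements[i]
        if target == 0:
            return chosen[::-1]
        # Some earlier element must end a subset with the remaining sum.
        i = next(k for k in range(i) if a[k][target])


elements = [3, 5, 2, 7]
table = subset_sum_table(elements, 10)
picked = backtrack(elements, table, 10)
print(picked, [elements[i] for i in picked])
```

The table costs O(len(elements)^2 * n) time as written; tracking a running "any earlier row" flag per column would remove the inner scan over k.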