想象一下,您正尝试在某些地区(例如n=10
)分配一些固定资源(例如t=5
)。我试图有效地找出如何获得总和为n
或更低的所有组合。
E.g。 10,0,0,0,0
很好,0,0,5,5,0
等等,而3,3,3,3,3,3
显然是错误的。
我到目前为止:
import itertools
t = 5
n = 10
r = [range(n+1)] * t
for x in itertools.product(*r):
if sum(x) <= n:
print x
这种蛮力方法虽然令人难以置信地缓慢;必须有更好的方法吗?
计时(1000次迭代):
Default (itertools.product) --- time: 40.90 s
falsetru recursion --- time: 3.63 s
Aaron Williams Algorithm (impl, Tony) --- time: 0.37 s
答案 0 :(得分:3)
创建自己的递归函数,除非可以得到一个总和&lt; = 10,否则不会使用元素递归。
def f(r, n, t, acc=[]):
if t == 0:
if n >= 0:
yield acc
return
for x in r:
if x > n: # <---- do not recurse if sum is larger than `n`
break
for lst in f(r, n-x, t-1, acc + [x]):
yield lst
t = 5
n = 10
for xs in f(range(n+1), n, 5):
print xs
答案 1 :(得分:2)
可能的方法如下。绝对会谨慎使用(根本没有测试,但n = 10和t = 5的结果看起来合理)。
该方法涉及 no 递归。生成具有m个元素(在您的示例中为5)的数字n(在您的示例中为10)的分区的算法来自Knuth的第4卷。如果需要,每个分区然后进行零扩展,并且使用来自Aaron Williams的算法生成所有不同的排列,我已经看到elsewhere。两种算法都必须转换为Python,这增加了错误悄悄进入的机会。威廉姆斯算法想要一个链表,我不得不用2D数组伪造,以避免编写链表类。
有一个下午!
代码(请注意,您的n
是我的maxn
,t
是p
}:
import itertools
def visit(a, m):
""" Utility function to add partition to the list"""
x.append(a[1:m+1])
def parts(a, n, m):
""" Knuth Algorithm H, Combinatorial Algorithms, Pre-Fascicle 3B
Finds all partitions of n having exactly m elements.
An upper bound on running time is (3 x number of
partitions found) + m. Not recursive!
"""
while (1):
visit(a, m)
while a[2] < a[1]-1:
a[1] -= 1
a[2] += 1
visit(a, m)
j=3
s = a[1]+a[2]-1
while a[j] >= a[1]-1:
s += a[j]
j += 1
if j > m:
break
x = a[j] + 1
a[j] = x
j -= 1
while j>1:
a[j] = x
s -= x
j -= 1
a[1] = s
def distinct_perms(partition):
""" Aaron Williams Algorithm 1, "Loopless Generation of Multiset
Permutations by Prefix Shifts". Finds all distinct permutations
of a list with repeated items. I don't follow the paper all that
well, but it _possibly_ has a running time which is proportional
to the number of permutations (with 3 shift operations for each
permutation on average). Not recursive!
"""
perms = []
val = 0
nxt = 1
l1 = [[partition[i],i+1] for i in range(len(partition))]
l1[-1][nxt] = None
#print(l1)
head = 0
i = len(l1)-2
afteri = i+1
tmp = []
tmp += [l1[head][val]]
c = head
while l1[c][nxt] != None:
tmp += [l1[l1[c][nxt]][val]]
c = l1[c][nxt]
perms.extend([tmp])
while (l1[afteri][nxt] != None) or (l1[afteri][val] < l1[head][val]):
if (l1[afteri][nxt] != None) and (l1[i][val]>=l1[l1[afteri][nxt]][val]):
beforek = afteri
else:
beforek = i
k = l1[beforek][nxt]
l1[beforek][nxt] = l1[k][nxt]
l1[k][nxt] = head
if l1[k][val] < l1[head][val]:
i = k
afteri = l1[i][nxt]
head = k
tmp = []
tmp += [l1[head][val]]
c = head
while l1[c][nxt] != None:
tmp += [l1[l1[c][nxt]][val]]
c = l1[c][nxt]
perms.extend([tmp])
return perms
maxn = 10 # max integer to find partitions of
p = 5 # max number of items in each partition
# Find all partitions of length p or less adding up
# to maxn or less
# Special cases (Knuth's algorithm requires n and m >= 2)
x = [[i] for i in range(maxn+1)]
# Main cases: runs parts fn (maxn^2+maxn)/2 times
for i in range(2, maxn+1):
for j in range(2, min(p+1, i+1)):
m = j
n = i
a = [0, n-m+1] + [1] * (m-1) + [-1] + [0] * (n-m-1)
parts(a, n, m)
y = []
# For each partition, add zeros if necessary and then find
# distinct permutations. Runs distinct_perms function once
# for each partition.
for part in x:
if len(part) < p:
y += distinct_perms(part + [0] * (p - len(part)))
else:
y += distinct_perms(part)
print(y)
print(len(y))
答案 2 :(得分:2)
您可以使用itertools
创建所有排名,并使用numpy
解析结果。
>>> import numpy as np
>>> from itertools import product
>>> t = 5
>>> n = 10
>>> r = range(n+1)
# Create the product numpy array
>>> prod = np.fromiter(product(r, repeat=t), np.dtype('u1,' * t))
>>> prod = prod.view('u1').reshape(-1, t)
# Extract only permutations that satisfy a condition
>>> prod[prod.sum(axis=1) < n]
Timeit:
>>> %%timeit
prod = np.fromiter(product(r, repeat=t), np.dtype('u1,' * t))
prod = prod.view('u1').reshape(-1, t)
prod[prod.sum(axis=1) < n]
10 loops, best of 3: 41.6 ms per loop
您甚至可以通过populating combinations directly in numpy来加速产品计算。
答案 3 :(得分:0)
您可以使用动态编程优化算法。
基本上,有一个数组a
,其中a[i][j]
表示“我可以获得j
的总和,其中包含j-th
元素之前的元素(并使用{ {1}}元素,假设您的元素在数组jth
中(不是您提到的数字)。)
然后你可以填充数组
t
然后,使用此信息,您可以回溯解决方案:)
a[0][t[0]] = True
for i in range(1, len(t)):
a[i][t[i]] = True
for j in range(t[i]+1, n+1):
for k in range(0, i):
if a[k][j-t[i]]:
a[i][j] = True