分布的迭代排列

时间:2018-06-06 22:10:28

标签: python recursion iteration permutation distribution

我正在尝试生成各种分布的所有可能组合。

例如,假设您在4个类别上花费5分,但您只能在任何给定类别上花费最多2分。 在这种情况下,所有可能的解决方案如下:

[0, 1, 2, 2]
[0, 2, 1, 2]
[0, 2, 2, 1]
[1, 0, 2, 2]
[1, 1, 1, 2]
[1, 1, 2, 1]
[1, 2, 0, 2]
[1, 2, 1, 1]
[1, 2, 2, 0]
[2, 0, 1, 2]
[2, 0, 2, 1]
[2, 1, 0, 2]
[2, 1, 1, 1]
[2, 1, 2, 0]
[2, 2, 0, 1]
[2, 2, 1, 0]

我已经成功地创建了一个完成此功能的递归函数,但是对于大量类别,生成需要很长时间。我试图制作一个迭代函数,希望加快速度,但我似乎无法将其考虑到类别最大值。

这是我的递归函数(count = points,dist =零填充数组,与max_allo大小相同)

def distribute_recursive(count, max_allo, dist, depth=0):
    for ration in range(max(count - sum(max_allo[depth + 1:]), 0), min(count, max_allo[depth]) + 1):
        dist[depth] = ration
        count -= ration
        if depth + 1 < len(dist):
            distribute_recursive(count, max_allo, dist, depth + 1)
        else:
            print(dist)
        count += ration

2 个答案:

答案 0 :(得分:3)

递归不慢

递归不是让它变慢的原因;考虑一个更好的算法

def dist (count, limit, points, acc = []):
  if count is 0:
    if sum (acc) is points:
      yield acc
  else:
    for x in range (limit + 1):
      yield from dist (count - 1, limit, points, acc + [x])

您可以在列表中收集生成的结果

print (list (dist (count = 4, limit = 2, points = 5)))

修剪无效组合

在上面,我们使用limit + 1的固定范围,但要注意如果我们与(例如)limit = 2points = 5生成组合会发生什么...

[ 2, ... ]    # 3 points remaining
[ 2, 2, ... ] # 1 point remaining

此时,使用limit + 1[ 0, 1, 2 ])的固定范围是愚蠢的,因为我们知道我们只剩下1分。这里剩下的唯一选项是01 ...

[ 2, 2, 1 ... ] # 0 points remaining

上面我们知道我们可以使用[ 0 ]的空范围,因为没有剩余的花费。这将阻止我们尝试验证像

这样的组合
[ 2, 2, 2, ... ] # -1 points remaining
[ 2, 2, 2, 0, ... ] # -1 points remaining
[ 2, 2, 2, 1, ... ] # -2 points remaining
[ 2, 2, 2, 2, ... ] # -3 points remaining

如果count非常大,则可以排除巨大数量的无效组合

[ 2, 2, 2, 2, 2, 2, 2, 2, 2, ... ] # -15 points remaining 

要实现此优化,我们可以在dist函数中添加另一个参数,但是在5个参数中,它会开始变得混乱。相反,我们引入了一个辅助函数来控制loop。添加我们的优化,我们交换固定范围的动态范围min (limit, remaining) + 1。最后,由于我们知道已经分配了多少点,我们不再需要测试每个组合的sum;从我们的算法中删除了另一个昂贵的操作

# revision: prune invalid combinations
def dist (count, limit, points):
  def loop (count, remaining, acc):
    if count is 0:
      if remaining is 0:
        yield acc
    else:
      for x in range (min (limit, remaining) + 1):
        yield from loop (count - 1, remaining - x, acc + [x])
  yield from loop (count, points, [])

<强>基准

在下面的基准测试中,我们程序的第一个版本已重命名为dist1,而使用动态范围dist2的程序更快。我们设置了三个测试,smallmediumlarge

def small (prg):
  return list (prg (count = 4, limit = 2, points = 5))

def medium (prg):
  return list (prg (count = 8, limit = 3, points = 7))

def large (prg):
  return list (prg (count = 16, limit = 5, points = 10))

现在我们运行测试,将每个程序作为参数传递。请注意large测试,只需完成1次传递,因为dist1需要一段时间才能生成结果

print (timeit ('small (dist1)', number = 10000, globals = globals ()))
print (timeit ('small (dist2)', number = 10000, globals = globals ()))

print (timeit ('medium (dist1)', number = 100, globals = globals ()))
print (timeit ('medium (dist2)', number = 100, globals = globals ()))

print (timeit ('large (dist1)', number = 1, globals = globals ()))
print (timeit ('large (dist2)', number = 1, globals = globals ()))

small测试的结果显示,修剪无效组合并没有太大的区别。但是在mediumlarge案例中,差异非常大。我们的旧程序需要30分钟以上的大型设备,但使用新程序只需1秒钟!

dist1 small      0.8512216459494084
dist2 small      0.8610155049245805   (0.98x speed-up)

dist1 medium     6.142372329952195
dist2 medium     0.9355670949444175   (6.57x speed-up)

dist1 large   1933.0877765258774
dist2 large      1.4107366011012346   (1370.26x speed-up)

对于参考框架,每个结果的大小打印在

下面
print (len (small (dist2)))   # 16      (this is the example in your question)
print (len (medium (dist2)))  # 2472
print (len (large (dist2)))   # 336336

检查我们的理解

在使用largecount = 12的{​​{1}}基准测试中,使用我们未优化的程序,我们迭代了5 12 或244,140,​​625 可能组合。使用我们优化的程序,我们跳过所有无效组合,产生336,336个有效答案。通过单独分析组合计数,我们发现惊人的99.86%的可能组合无效。如果对每个组合的分析花费相同的时间,我们可以预期,由于无效的组合修剪,我们优化的程序的性能至少会提高725.88x。

limit = 5基准测试中,测量速度提高了1370.26倍,优化后的计划符合我们的预期,甚至超越了我们的预期。额外的加速可能是因为我们取消了对large

的调用

<强> huuuuge

要显示此技术适用于超大型数据集,请考虑sum基准。我们的计划在7个 16 或33,232,930,569,601种可能性中找到了17,321,844个有效组合。

在此测试中,我们的优化程序修剪了99.99479%的无效组合。将这些数字与之前的数据集相关联,我们估计优化程序比未优化版本快1,918,556.16x。

使用未经优化的计划的该基准的理论运行时间 117。60年。优化后的程序只需1分钟即可找到答案。

huge

答案 1 :(得分:2)

您可以使用生成器函数进行递归,同时应用其他逻辑来减少所需的递归调用次数:

def listings(_cat, points, _max, current = []):
   if len(current) == _cat:
      yield current
   else:
      for i in range(_max+1):
        if sum(current+[i]) <= points:
          if sum(current+[i]) == points or len(current+[i]) < _cat:
             yield from listings(_cat, points, _max, current+[i])


print(list(listings(4, 5, 2)))

输出:

[[0, 1, 2, 2], [0, 2, 1, 2], [0, 2, 2, 1], [1, 0, 2, 2], [1, 1, 1, 2], [1, 1, 2, 1], [1, 2, 0, 2], [1, 2, 1, 1], [1, 2, 2, 0], [2, 0, 1, 2], [2, 0, 2, 1], [2, 1, 0, 2], [2, 1, 1, 1], [2, 1, 2, 0], [2, 2, 0, 1], [2, 2, 1, 0]]

虽然目前还不清楚您的解决方案大小减慢了什么类别大小,但对于类别大小最多24,此解决方案将在一秒钟内运行,搜索总共五个点,最大插槽值为2 。请注意,对于大点和槽值,在一秒钟内计算的可能类别大小的数量会增加:

import time

def timeit(f):
   def wrapper(*args):
     c = time.time()
     _ = f(*args)
     return time.time() - c
   return wrapper

@timeit
def wrap_calls(category_size:int) -> float:
  _ = list(listings(category_size, 5, 2))

benchmark = 0
category_size = 1
while benchmark < 1:
   benchmark = wrap_calls(category_size)
   category_size += 1

print(category_size)

输出:

24