列表列表上的递归生成器

时间:2014-12-02 20:23:20

标签: python recursion generator

我有2个降级列表,看起来像这样:

[[1,2,3,],
[4,5,6],
[7,8,9]] 

我正在尝试编写一个产生“路径”总和的生成器。

'路径'从左上角开始,仅在x + 1和y + 1上开始,直到它到达它的最后一个元素(右下角)。

例如,有效路径为1=>2=>5=>6=>9 (sum=23)

无效路径可以是1=>2=>5=>**4**=> ...

到目前为止,我有这段代码:

my_list = [[0, 2, 5], [1, 1, 3], [2, 1, 1]]   
def gen(x, y, _sum):
    if x + 1 <= len(my_list):
        for i1 in gen(x + 1, y, _sum + my_list[y][x]):
            yield _sum
    if y + 1 <= len(my_list):
        for i2 in gen(x, y + 1, _sum + my_list[y][x]):
            yield _sum
    yield _sum + my_list[y][x]


g = gen(0, 0, 0)
total = 0
for elm in g:
    total += elm

print total

我收到错误:

  for i2 in gen(x, y+1, _sum+my_list[y][x]):
IndexError: list index out of range

2 个答案:

答案 0 :(得分:2)

此错误的原因是一个简单的逐个错误。*

我认为您想要的是x <= len(my_list)或等同于x+1 < len(my_list);你已经加倍了+ 1-ness,导致你跑过列表的末尾。

考虑具体案例:

  • len(my_list)为3. x2。因此,x+1 <= len(my_list)3 <= 3,这是真的。所以你用gen(3, …)递归地称呼自己。
  • 在该递归调用中,4 <= 3为false,因此,根据y的值,您可以调用:
    • gen(x, y + 1, _sum + my_list[y][3])
    • _sum + my_list[y][3]
    • ...其中任何一个都会引发IndexError

显然,您需要使用y解决与x相同的问题。

您可以看到它正常运行here


当然它实际上并没有打印出正确的结果,因为您的代码中还有其他问题。在我的头顶:

  • total = + elmtotal中的任何内容替换为elm的值。您可能想要+=,而不是= +
  • 反复产生_sum并忽略递归生成器产生的值不可能有任何好处。也许您想要产生i1i2呢?

我不能保证这些是您代码中唯一的问题,只是它们是问题。


*我在这里假设这是一个愚蠢的错误,而不是一个基本的错误 - 你清楚地知道索引是基于0的,因为你用gen(0, 0, 0)调用了函数而不是{{1 }}

答案 1 :(得分:2)

如果你真的想要通过N×M矩阵强制所有允许的路径,那么只需生成N-1向右移动的所有排列加上M-1向下移动,然后使用这些移动来对值进行求和路径:

from itertools import permutations

def gen_path_sum(matrix):
    N, M = len(matrix), len(matrix[0])
    for path in permutations([(1, 0)] * (N - 1) + [(0, 1)] * (M - 1)):
        sum = matrix[0][0]
        x = y = 0
        for dx, dy in path:
            x += dx; y += dy
            sum += matrix[x][y]
        yield sum

这将产生(N + M)!路径;对于3乘3矩阵,有720个这样的路径。

但是,如果您试图通过矩阵找到最大路径,那么您将采用效率低下的方式。

您可以改为计算矩阵中任何单元格的最大路径;它只是上方和左侧单元格的最大路径值中的最大值,加上当前单元格的值。因此,对于左上角的单元格(上方或右侧没有单元格),最大路径值是单元格的值。

您可以使用N X M循环计算所有这些值:

def max_path_value(matrix):
    totals = [row[:] for row in matrix]
    for x, row in enumerate(totals):
        for y, cell in enumerate(row):
            totals[x][y] += max(
                totals[x - 1][y] if x else 0,
                totals[x][y - 1] if y else 0
            )
    return totals[-1][-1]

对于3乘3矩阵,这只需要N X M步,或总共9步。这比蛮力方法好80倍。

对比度仅随着矩阵大小的增加而增加;一个10x10矩阵,暴力强制,需要检查2432902008176640000路径(== 20!),或者你可以用100步计算最大路径。