itertools.product如何在不将中间结果保存在内存中的情况下计算笛卡尔积

时间:2019-03-15 14:22:44

标签: python recursion itertools cartesian-product

根据文档here,iterpools.product在内存中不保存中间结果(它计算输入列表的笛卡尔乘积)。但是给出的算法的粗略草图使我相信它确实可以做到。请注意结果如何通过在结果中添加元素并向其添加更多内容而在每次迭代中保持更新。

def product(*args, repeat=1):
    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
    pools = [tuple(pool) for pool in args] * repeat
    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)

我尝试添加底层的 C 代码here,但没有这样做。我想了解 C 代码如何在不将中间结果保存在内存中的情况下工作。我遇到了一种递归方法(如下所示),除了递归调用堆栈,该方法不会将中间结果保留在内存中。 C代码是否还使用递归方法,否则它如何在不将中间结果保存在内存中的情况下工作?

// Recursive approach
def product(input, ans, i): 
    if i == len(input): 
        print(ans) 
        return 
    j = 0 
    while j < len(input[i]): 
        ans.append(input[i][j]) 
        find(input, ans, i+1) 
        ans.pop() 
        j += 1

input = [] 
input.append(["hi", "hey"]) 
input.append(["there", "here"]) 
input.append(["cute", "handsome"])  

ans = [] 
print(product(input, ans, 0))

hi there cute
hi there handsome
....

1 个答案:

答案 0 :(得分:2)

它将输入(以tuple的形式存储在内存中,以及每个tuple的索引,并重复循环除第一个以外的所有内容。每次请求一个新的输出值时,它:

  1. 将索引前进到最右边的tuple
  2. 如果它超过结尾,它将重置为零,并前进到下一个最右边的索引
  3. 重复步骤2,直到找到一个无需增加其特定迭代器末尾即可增加的索引
  4. 通过为每个数据源提取当前索引处的值来创建新的tuple

在第一次拉入时有一个特殊情况,即它仅从每个tuple中拉取第0个值,否则每次都遵循该模式。

对于一个非常简单的示例,其内部状态为:

for x, y in product('ab', 'cd'):

将首先创建元组('a', 'b')('c', 'd'),以及最初创建的索引数组[0, 0]。第一次拉动时,它会产生('a', 'b')[0], ('c', 'd')[0]('a', 'c')。在下一次拉动时,它将索引数组前进到[0, 1],并产生('a', 'b')[0], ('c', 'd')[1]('a', 'd')。下一个拉取将最右边的索引前进到2,意识到它已经溢出,将其放回0,并前进下一个索引使其成为[1, 0],并产生('a', 'b')[1], ('c', 'd')[0]('b', 'c')。这一直持续到最左边的索引溢出为止,此时迭代完成。

实际上等效的Python代码看起来更像:

def product(*iterables, repeat=1):
    tuples = [tuple(it) for it in iterables] * repeat
    if not all(tuples):
        return # A single empty input means nothing to yield
    indices = [0] * len(tuples)
    yield tuple(t[i] for i, t in zip(indices, tuples))
    while True:
        # Advance from rightmost index moving left until one of them
        # doesn't cycle back to 0
        for i in range(len(indices))[::-1]:
            indices[i] += 1
            if indices[i] < len(tuples[i]):
                break  # Done advancing for this round
            indices[i] = 0  # Cycle back to 0, advance next
        else:
            # The above loop will break at some point unless
            # the leftmost index gets cycled back to 0
            # (because all the leftmost values have been used)
            # so if reach the else case, all products have been computed
            return

        yield tuple(t[i] for i, t in zip(indices, tuples))

但是您可以看到,它比提供的简单版本复杂得多。

如您所见,每个输出tuple在创建后立即yield被创建;只有输入和这些输入的当前索引必须保留为迭代器状态。因此,只要调用者不存储结果,而只进行实时迭代,则只需要很少的内存。