Python: calling list() on a generator object produces incorrect results

Time: 2018-02-16 05:55:05

Tags: python list generator yield iterable

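I was studying the accepted solution to this question, which provides a Python implementation of an algorithm for generating unique permutations in lexicographic order. I have a shortened implementation: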

def permutations(seq):
    seq = sorted(seq)
    while True:
        yield seq
        k = l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return

        (seq[k], seq[l]) = (seq[l], seq[k])
        seq[k + 1:] = seq[-1:k:-1]

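What's really strange to me is that if I call list on the output of this function, I get wrong results. However, if I iterate over the results of this function one at a time, I get the expected results.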

>>> list(permutations((1,2,1)))
[[2, 1, 1], [2, 1, 1], [2, 1, 1]]
>>> for p in permutations((1,2,1)):
...   print(p)
... 
[1, 1, 2]
[1, 2, 1]
[2, 1, 1]

^^^ What?! Another example:

>>> list(permutations((1,2,3)))
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]
>>> for p in permutations((1,2,3)):
...   print(p)
... 
[1, 2, 3]
[2, 3, 1]
[3, 1, 2]
[3, 2, 1]

A list comprehension also produces incorrect values:

>>> [p for p in permutations((1,2,3))]
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]

I have no idea what's going on here! I've never seen this before. I can write other functions that use generators and I don't run into this:

>>> def seq(n):
...   for i in range(n):
...     yield i
... 
>>> list(seq(5))
[0, 1, 2, 3, 4]

What is going on in my example above that causes this?

1 answer:

Answer 0 (score: 8)

You are modifying seq inside the generator after you yield it. You keep yielding the same object, and then mutating it.

    (seq[k], seq[l]) = (seq[l], seq[k]) # this mutates seq
    seq[k + 1:] = seq[-1:k:-1] # this mutates seq

Note that your list contains the same object multiple times:

In [2]: ps = list(permutations((1,2,1)))

In [3]: ps
Out[3]: [[2, 1, 1], [2, 1, 1], [2, 1, 1]]

In [4]: [hex(id(p)) for p in ps]
Out[4]: ['0x105cb3b48', '0x105cb3b48', '0x105cb3b48']
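
The same effect can be reproduced without any permutation logic, which may also explain why the simple seq(n) generator in the question is unaffected. The following is only a minimal sketch; the names fresh_values and shared_buffer are made up for illustration. The first generator yields a different int each time, while the second yields one list object and mutates it between yields.

```python
def fresh_values(n):
    # Each iteration binds i to a different int; values yielded
    # earlier are never changed afterwards.
    for i in range(n):
        yield i


def shared_buffer(n):
    # Hypothetical counterpart: one list object is yielded every time
    # and mutated in place between yields.
    buf = [None]
    for i in range(n):
        buf[0] = i
        yield buf


print(list(fresh_values(3)))  # [0, 1, 2]

collected = list(shared_buffer(3))
print(collected)  # [[2], [2], [2]]
# All three entries are the very same object, seen in its final state.
print(all(item is collected[0] for item in collected))  # True
```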

So, try yielding a copy:

def permutations(seq):
    seq = sorted(seq)
    while True:
        yield seq.copy()
        k = None
        l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return

        (seq[k], seq[l]) = (seq[l], seq[k])
        seq[k + 1:] = seq[-1:k:-1]

And, voila:

In [5]: def permutations(seq):
   ...:     seq = sorted(seq)
   ...:     while True:
   ...:         yield seq.copy()
   ...:         k = None
   ...:         l = None
   ...:         for k in range(len(seq) - 1):
   ...:             if seq[k] < seq[k + 1]:
   ...:                 l = k + 1
   ...:                 break
   ...:         else:
   ...:             return
   ...:
   ...:         (seq[k], seq[l]) = (seq[l], seq[k])
   ...:         seq[k + 1:] = seq[-1:k:-1]
   ...:

In [6]: ps = list(permutations((1,2,1)))

In [7]: ps
Out[7]: [[1, 1, 2], [1, 2, 1], [2, 1, 1]]
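
If the generator itself cannot be changed, copying on the consumer side should work as well, because each item is copied before the generator resumes and mutates it. The sketch below assumes the question's original generator, renamed permutations_shared here purely for illustration. Note also that list.copy() requires Python 3.3 or later; list(seq) or seq[:] are equivalent shallow copies on older versions.

```python
def permutations_shared(seq):
    # The question's generator, unchanged apart from its (hypothetical)
    # name: it yields one list object and mutates it after each yield.
    seq = sorted(seq)
    while True:
        yield seq
        k = l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return
        seq[k], seq[l] = seq[l], seq[k]
        seq[k + 1:] = seq[-1:k:-1]


# Copy each item as it arrives, before the generator is resumed.
as_lists = [list(p) for p in permutations_shared((1, 2, 1))]
as_tuples = [tuple(p) for p in permutations_shared((1, 2, 1))]

print(as_lists)   # [[1, 1, 2], [1, 2, 1], [2, 1, 1]]
print(as_tuples)  # [(1, 1, 2), (1, 2, 1), (2, 1, 1)]
```

Converting to tuples has the side benefit that the collected results are immutable, so later code cannot alter them by accident.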

As for why the print in the for loop did not reveal this behavior: at the moment of each iteration, seq has the "correct" value. Consider:

In [10]: result = []
    ...: for i, x in enumerate(permutations((1,2,1))):
    ...:     print("iteration ", i)
    ...:     print(x)
    ...:     result.append(x)
    ...:     print(result)
    ...:
iteration  0
[1, 1, 2]
[[1, 1, 2]]
iteration  1
[1, 2, 1]
[[1, 2, 1], [1, 2, 1]]
iteration  2
[2, 1, 1]
[[2, 1, 1], [2, 1, 1], [2, 1, 1]]
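
Put differently, print renders the list's contents at that instant, whereas append only stores a reference to the object. A small sketch of that difference follows; the generator growing and the variable names are assumptions chosen for illustration, not part of the original code.

```python
def growing():
    # Hypothetical generator: yields the same list each time,
    # after appending one more element to it.
    buf = []
    for i in range(3):
        buf.append(i)
        yield buf


printed = []  # what print() would have shown at each step
stored = []   # references to the yielded object

for x in growing():
    printed.append(str(x))  # str() captures the contents right now
    stored.append(x)        # keeps a reference, not a snapshot

print(printed)  # ['[0]', '[0, 1]', '[0, 1, 2]']
print(stored)   # [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
```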