使用python中的索引展平嵌套列表

时间:2019-03-11 05:57:58

标签: python-3.x

我有一个列表['','','',['',[['a','b']['c']]],[[['a','b'],['c']]],[[['d']]]]

我想用索引将列表弄平,输出应如下:

flat list=['','','','','a','b','c','a','b','c','d']
indices=[0,1,2,3,3,3,3,4,4,4,5]

该怎么做?

我已经尝试过了:

def flat(nums):
    res = []
    index = []
    for i in range(len(nums)):
        if isinstance(nums[i], list):
            res.extend(nums[i])
            index.extend([i]*len(nums[i]))
        else:
            res.append(nums[i])
            index.append(i)
    return res,index

但这不能按预期工作。

3 个答案:

答案 0 :(得分:4)

TL; DR

此实现可处理深度无限的嵌套可迭代对象:

def enumerate_items_from(iterable):
    cursor_stack = [iter(iterable)]
    item_index = -1
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:
            cursor_stack.pop()
            continue
        if len(cursor_stack) == 1:
            item_index += 1
        if not isinstance(item, str):
            try:
                cursor_stack.append(iter(item))
                continue
            except TypeError:
                pass
        yield item, item_index

def flat(iterable):
    return map(list, zip(*enumerate_items_from(a)))

可用于产生所需的输出:


>>> nested = ['', '', '', ['', [['a', 'b'], ['c']]], [[['a', 'b'], ['c']]], [[['d']]]]
>>> flat_list, item_indexes = flat(nested)
>>> print(item_indexes)
[0, 1, 2, 3, 3, 3, 3, 4, 4, 4, 5]
>>> print(flat_list)
['', '', '', '', 'a', 'b', 'c', 'a', 'b', 'c', 'd']

请注意,您可能应该将索引放在第一位以模仿enumerate的行为。对于已经认识enumerate的人来说,使用起来会更容易。

重要提示,除非您确定列表不会嵌套太多,否则不应使用任何基于递归的解决方案。否则,一旦您拥有深度大于1000的嵌套列表,您的代码就会崩溃。我将对此here进行说明。请注意,对str(list)的简单调用将在使用depth > 1000的测试用例上崩溃(对于某些python实现,它不止于此,但始终受限制)。使用基于递归的解决方案时,典型的例外是(简而言之是由于python调用堆栈的工作原理):

RecursionError: maximum recursion depth exceeded ... 

实施细节

我将逐步进行操作,首先将平整一个列表,然后输出平整后的列表和所有项目的深度,最后我们将在“ main”中输出该列表和相应的项目索引清单”。

整理列表

话虽如此,这实际上是很有趣的,因为迭代解决方案是专门为此设计的,您可以采用一种简单的(非递归)列表展平算法:

def flatten(iterable):
    return list(items_from(iterable))

def items_from(iterable):
    cursor_stack = [iter(iterable)]
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:       # post-order
            cursor_stack.pop()
            continue
        if isinstance(item, list):  # pre-order
            cursor_stack.append(iter(item))
        else:
            yield item              # in-order

计算深度

我们可以通过查看堆栈大小depth = len(cursor_stack) - 1

来访问深度
        else:
            yield item, len(cursor_stack) - 1      # in-order

这将返回对(项目,深度)的迭代,如果我们需要将结果分成两个迭代器,我们可以使用zip函数:

>>> a = [1,  2,  3, [4 , [[5, 6], [7]]], [[[8, 9], [10]]], [[[11]]]]
>>> flatten(a)
[(1, 0), (2, 0), (3, 0), (4, 1), (5, 3), (6, 3), (7, 3), (8, 3), (9, 3), (10, 3), (11, 3)]
>>> flat_list, depths = zip(*flatten(a))
>>> print(flat_list)
(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
>>> print(depths)
(0, 0, 0, 1, 3, 3, 3, 3, 3, 3, 3)

我们现在将执行类似的操作,以使用项目索引而不是深度。

计算项目索引

要计算项目索引(在主列表中),您需要计算到目前为止已看到的项目数,这可以通过在每次迭代时将item_index加1来完成在深度为0的项目上(当堆栈大小等于1时):

def flatten(iterable):
    return list(items_from(iterable))

def items_from(iterable):
    cursor_stack = [iter(iterable)]
    item_index = -1
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:             # post-order
            cursor_stack.pop()
            continue
        if len(cursor_stack) == 1:        # If current item is in "main" list
            item_index += 1               
        if isinstance(item, list):        # pre-order
            cursor_stack.append(iter(item))
        else:
            yield item, item_index        # in-order

类似地,我们将使用ˋzip , we will also use ˋmap将对分解为两个itératifs,以将两个迭代器都转换为列表:

>>> a = [1,  2,  3, [4 , [[5, 6], [7]]], [[[8, 9], [10]]], [[[11]]]]
>>> flat_list, item_indexes = map(list, zip(*flatten(a)))
>>> print(flat_list)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
>>> print(item_indexes)
[0, 1, 2, 3, 3, 3, 3, 4, 4, 4, 5]

改进-处理可迭代的输入

可能希望能够使用更大范围的嵌套可迭代项作为输入(特别是如果将其构建为供其他人使用)。例如,如果将嵌套的可迭代对象作为输入,则当前的实现将无法按预期运行,例如:

>>> a = iter([1, '2',  3, iter([4, [[5, 6], [7]]])])
>>> flat_list, item_indexes = map(list, zip(*flatten(a)))
>>> print(flat_list)
[1, '2', 3, <list_iterator object at 0x100f6a390>]
>>> print(item_indexes)
[0, 1, 2, 3]

如果我们希望此方法有效,则需要谨慎一点,因为字符串是可迭代的,但我们希望将它们视为原子项(而不是字符列表)。与其像以前那样假设输入是列表,就没有:

        if isinstance(item, list):        # pre-order
            cursor_stack.append(iter(item))
        else:
            yield item, item_index        # in-order

我们不会检查输入类型,而是将其视为可迭代的来使用,如果输入失败,我们将知道它不是可迭代的(鸭子输入):

       if not isinstance(item, str):
            try:
                cursor_stack.append(iter(item))
                continue
            # item is not an iterable object:
            except TypeError:
                pass
        yield item, item_index

通过此实现,我们可以:

>>> a = iter([1, 2,  3, iter([4, [[5, 6], [7]]])])
>>> flat_list, item_indexes = map(list, zip(*flatten(a)))
>>> print(flat_list)
[1, 2, 3, 4, 5, 6, 7]
>>> print(item_indexes)
[0, 1, 2, 3, 3, 3, 3]

构建测试用例

如果您需要生成深度较大的测试用例,则可以使用以下代码:

def build_deep_list(depth):
    """Returns a list of the form $l_{depth} = [depth-1, l_{depth-1}]$
    with $depth > 1$ and $l_0 = [0]$.
    """
    sub_list = [0]
    for d in range(1, depth):
        sub_list = [d, sub_list]
    return sub_list

您可以使用它来确保深度不大时我的实现不会崩溃:

a = build_deep_list(1200)
flat_list, item_indexes = map(list, zip(*flatten(a)))

我们还可以使用str函数来检查是否无法打印出这样的列表:

>>> a = build_deep_list(1200)
>>> str(a)
RecursionError: maximum recursion depth exceeded while getting the repr of an object

函数repr在输入str(list)的每个元素上由list调用。

结束语

最后,我同意递归实现更容易阅读(因为调用栈为我们完成了一半的辛苦工作),但是当实现这样的低级功能时,我认为拥有一个代码可以很好地投资在所有情况下(或至少您能想到的所有情况下)都有效。尤其是当解决方案不是那么困难时。这也是一种不忘记如何编写在树状结构上工作的非递归代码的方法(除非您自己实现数据结构,否则这种情况可能不会发生很多,但这是一个好习惯)。

请注意,我所说的“反对”递归只是因为python在面对递归Tail Recursion Elimination in Python时不会优化调用堆栈的使用。许多编译语言都使用Tail Call recursion Optimization (TCO)。这意味着,即使您用python编写了完美的tail-recursive函数,它也会在深层嵌套的列表上崩溃。

如果您需要有关列表展平算法的更多详细信息,请参阅我链接的帖子。

答案 1 :(得分:0)

简单而优雅的解决方案:

def flat(main_list):

    res = []
    index = []

    for main_index in range(len(main_list)):
        # Check if element is a String
        if isinstance(main_list[main_index], str):
            res.append(main_list[main_index])
            index.append(main_index)

        # Check if element is a List
        else:
            sub_list = str(main_list[main_index]).replace('[', '').replace(']', '').replace(" ", '').replace("'", '').split(',')
            res += sub_list
            index += ([main_index] * len(sub_list))

    return res, index

答案 2 :(得分:0)

这可以完成工作,但是如果您希望将其退回,那么我会为您增强功能

from pprint import pprint

ar = ["","","",["",[["a","b"],["c"]]],[[["a","b"],["c"]]],[[["d"]]]]
flat = []
indices= []

def squash(arr,indx=-1):
    for ind,item in enumerate(arr):
        if isinstance(item, list):
            squash(item,ind if indx==-1 else indx)
        else:
            flat.append(item)
            indices.append(ind if indx==-1 else indx)

squash(ar)

pprint(ar)
pprint(flat)
pprint(indices)

编辑

这是如果您不想将列表保留在内存中并返回它们

from pprint import pprint

ar = ["","","",["",[["a","b"],["c"]]],[[["a","b"],["c"]]],[[["d"]]]]

def squash(arr,indx=-1,fl=[],indc=[]):
    for ind,item in enumerate(arr):
        if isinstance(item, list):
            fl,indc = squash(item,ind if indx==-1 else indx, fl, indc)
        else:
            fl.append(item)
            indc.append(ind if indx==-1 else indx)
    return fl,indc

flat,indices = squash(ar)

pprint(ar)
pprint(flat)
pprint(indices)

我不希望您需要超过1k的递归深度,这是默认设置