更多pythonic编写这种递归函数的方法

时间:2017-01-06 20:50:08

标签: python recursion list-comprehension

此函数用于获取通用字典(可能递归地包含其他字典和列表),并将其所有内容放在单个线性列表中。

def make_a_list(a):
  print type(a)
  if (type(a) == type({})):
    return make_a_list(a.keys()) + make_a_list(a.values())
  elif (type(a) == type([])):
    if len(a) > 1:
      return make_a_list(a[0]) + make_a_list(a[1:])
    return a
  else:
    return [a]

它完成了它的工作,但我想知道: a)我忘记了任何重要的数据类型吗? (例如,我忘了套) b)写一个更加pythonic的方式是什么? (尤其是我可以写一个列表理解吗?)

3 个答案:

答案 0 :(得分:2)

您可以使用yield

来避免功能中的列表创建/连接
def make_a_list(a):
  if isinstance(a, dict):
    yield from make_a_list(a.keys())
    yield from make_a_list(a.values())
  elif isinstance(a, (list, tuple, set)):
    for x in a:
      yield from make_a_list(x)
  else:
    yield a

这是一个生成器,所以如果你真的需要一个列表,你可以这样做:

def make_a_real_list(a):
    return list(make_a_list(a))

另请注意isinstance比直接比较类型更好。

答案 1 :(得分:0)

您总是可以使用自己的堆栈/队列来摆脱这些简单类型的递归。

首先,使用评论中建议的正确测试,重要的是__iter__事物。

然后:

things_to_flatten = []
things_to_flatten.append(a)
new_list = []
while things_to_flatten:
  current = things_to_flatten.pop(0)
  if isinstance(current, dict):
      things_to_flatten.extend(current.keys())
      things_to_flatten.extend(current.values())
  elif hasattr(current, '__iter__'):
      things_to_flatten.extend(current)
  else:
      new_list.append(current)

可能会进行一些调整以提高效率,例如意识到字典键不能成为词典或列表。但是,它们可能是元组,而且它们是可迭代的,所以......最好坚持一般检查。

答案 2 :(得分:0)

我可以推荐以下解决方案吗? mainmake_a_list函数会测试您的想法并显示更好的方法来分别实现它。如果您不介意使用可迭代对象和生成器的概念,test函数和flatten生成器可以更好地演示如何解决问题。请调整您的代码以获得最佳效果并且效果更佳。

#! /usr/bin/env python3
def main():
    obj = 1
    print('make_a_list({!r}) = {!r}'.format(obj, make_a_list(obj)))
    obj = {1, 2, 3}
    print('make_a_list({!r}) = {!r}'.format(obj, make_a_list(obj)))
    obj = [1, 2, 3]
    print('make_a_list({!r}) = {!r}'.format(obj, make_a_list(obj)))
    obj = [1]
    print('make_a_list({!r}) = {!r}'.format(obj, make_a_list(obj)))
    obj = 'a', 'b', 'c'
    print('make_a_list({!r}) = {!r}'.format(obj, make_a_list(obj)))
    obj = {1: 2, 3: 4, 5: 6}
    print('make_a_list({!r}) = {!r}'.format(obj, make_a_list(obj)), end='\n\n')


def make_a_list(obj):
    if isinstance(obj, dict):
        return make_a_list(list(obj.keys())) + make_a_list(list(obj.values()))
    if isinstance(obj, list):
        if len(obj) > 1:
            return make_a_list(obj[0]) + make_a_list(obj[1:])
        return obj
    return [obj]


def test():
    obj = 1
    print('list(flatten({!r})) = {!r}'.format(obj, list(flatten(obj))))
    obj = {1, 2, 3}
    print('list(flatten({!r})) = {!r}'.format(obj, list(flatten(obj))))
    obj = [1, 2, 3]
    print('list(flatten({!r})) = {!r}'.format(obj, list(flatten(obj))))
    obj = [1]
    print('list(flatten({!r})) = {!r}'.format(obj, list(flatten(obj))))
    obj = 'a', 'b', 'c'
    print('list(flatten({!r})) = {!r}'.format(obj, list(flatten(obj))))
    obj = {1: 2, 3: 4, 5: 6}
    print('list(flatten({!r})) = {!r}'.format(obj, list(flatten(obj))))


def flatten(iterable):
    if isinstance(iterable, (list, tuple, set, frozenset)):
        for item in iterable:
            yield from flatten(item)
    elif isinstance(iterable, dict):
        for item in iterable.keys():
            yield from flatten(item)
        for item in iterable.values():
            yield from flatten(item)
    else:
        yield iterable


if __name__ == '__main__':
    main()
    test()