使用python itertools来管理嵌套的for循环

时间:2015-03-20 19:39:54

标签: python arrays loops itertools

我正在尝试使用itertools.product来管理一些嵌套for循环的簿记,其中嵌套循环的数量事先不知道。下面是一个特定的例子,我选择了两个嵌套的for循环;选择两个只是为了清楚,我需要的是一个适用于任意数量的循环的解决方案。

此问题提供了此处出现的问题的扩展/概括: Efficient algorithm for evaluating a 1-d array of functions on a same-length 1d numpy array

现在我使用我在这里学到的itertools技巧扩展了上述技巧: Iterating over an unknown number of nested loops in python

序言:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table  = []
for items in product(*joint):
    f = trivial_functional(*items)
    func_table.append(f)

在上面的itertools循环结束时,我有一个12元素的1-d函数数组, func_table ,每个元素都是从trivial_functional构建的。

问题:

假设给出了一对整数(i_1,i_2),其中这些整数分别被解释为 idx1 idx2 的索引。如何使用itertools.product确定 func_table 数组的正确对应元素?

我知道如何通过编写模仿itertools.product簿记的函数来破解答案,但是当然有一个itertools.product的内置功能用于此目的吗?

4 个答案:

答案 0 :(得分:2)

除了自己动手之外,我不知道计算扁平指数的方法。幸运的是,这并不困难:

def product_flat_index(factors, indices):
  if len(factors) == 1: return indices[0]
  else: return indices[0] * len(factors[0]) + product_flat_index(factors[1:], indices[1:])

>> product_flat_index(joint, (2, 1))
9

另一种方法是首先将结果存储在嵌套数组中,不需要进行转换,但这更复杂:

from functools import reduce
from operator import getitem, setitem, itemgetter

def get_items(container, indices):
  return reduce(getitem, indices, container)

def set_items(container, indices, value):
  c = reduce(getitem, indices[:-1], container)
  setitem(c, indices[-1], value)

def initialize_table(lengths):
  if len(lengths) == 1: return [0] * lengths[0]
  subtable = initialize_table(lengths[1:])
  return [subtable[:] for _ in range(lengths[0])]

func_table = initialize_table(list(map(len, joint)))
for items in product(*map(enumerate, joint)):
  f = trivial_functional(*map(itemgetter(1), items))
  set_items(func_table, list(map(itemgetter(0), items)), f)

>>> get_items(func_table, (2, 1)) # same as func_table[2][1]
<function>

答案 1 :(得分:2)

由于每个人都有解决方案,所以很多答案都非常有用。

事实证明,如果我用Numpy重新解决这个问题,我可以完成相同的簿记,并解决我试图解决的问题,相对于纯python解决方案,速度大大提高。诀窍就是使用Numpy的 reshape 方法和普通的多维数组索引语法。

这是如何工作的。我们只是将 func_table 转换为Numpy数组,然后重新整形:

func_table = np.array(func_table)
component_dimensions = [len(idx1), len(idx2)]
func_table = np.array(func_table).reshape(component_dimensions)

现在 func_table 可用于返回正确的函数,不仅适用于单个2d点,还适用于完整的2d点数组:

dim1_pts = [3,1,2,1,3,3,1,3,0]
dim2_pts = [0,1,2,1,2,0,1,2,1]
func_array = func_table[dim1_pts, dim2_pts]
像往常一样,Numpy来救援!

答案 2 :(得分:1)

这有点乱,但是你走了:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [enumerate(idx1), enumerate(idx2)]

func_map  = {}
for indexes, items in map(lambda x: zip(*x), product(*joint)):
    f = trivial_functional(*items)
    func_map[indexes] = f

print(func_map[(2, 0)](5)) # 40 = (3+5)*5

答案 3 :(得分:0)

我建议在正确的位置使用enumerate()

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table  = []
for items in product(*joint):
     f = trivial_functional(*items)
     func_table.append(f)

根据我对您的评论和代码的理解,func_table只是通过序列中某个输入的出现来编入索引。您可以使用以下方式再次访问它:

for index, items in enumerate(product(*joint)):
    # because of the append(), index is now the 
    # position of the function created from the 
    # respective tuple in join()
    func_table[index](some_value)