我的目标是找到nxn矩阵中对角线的任何排列(2 <= n <= 15)。矩阵由零和1组成。
目前我这样做:
indices = [[j for j, x in enumerate(row) if x == 1]
for row in self.matrix]
cart = list(itertools.product(*indices))
cart = [list(tup) for tup in cart]
cart = filter(lambda dia: len(list(set(dia))) == len(dia), cart)
return cart
如果矩阵不是太大,这可以正常工作,但是否则会失败: 的MemoryError
那么有没有办法避免购物车的整个计算?为了找到例如一个排列,计算停止了吗?
答案 0 :(得分:1)
不要在itertools.product
的结果上调用list
并使用itertools.ifilter
代替filter
来简单地使所有评估变得懒惰:
from itertools import ifilter, product
indices = [[j for j, x in enumerate(row) if x == 1] for row in self.matrix]
cart = product(*indices)
found_cart = next(ifilter(lambda dia: len(set(dia)) == len(dia), cart), None)
next
返回ifilter
中谓词为True
的第一种情况,或者在没有匹配项的情况下返回None
。
一旦找到匹配项,计算就会停止。
答案 1 :(得分:0)
您可以简化代码的最后一部分,让它只返回第一个答案:
def foo(matrix):
indices = [[j for j, x in enumerate(row) if x == 1] for row in matrix]
# this part is changed, very simple and efficient now
for dia in itertools.product(*indices):
if len(set(dia)) == len(dia):
return dia
换句话说,不要那么聪明filter
和lambda以及所有这些 - 这是不必要的。