我有4个数组的阵列(set1,set2,...)。 E.g。
set1 = [array([1, 0, 0]), array([-1, 0, 0]), array([0, 1, 0]), ...]
我需要找到多少向量组合总和为零。解决这个问题的简单方法是:
for b1 in set1:
for b2 in set2:
for b3 in set3:
for b4 in set4:
if all(b1 + b2 + b3 + b4 == 0):
count = count + 1
然而,这就像O(n ^ 4),并且基于3sum算法,我假设我可以做O(n ^ 3)并且速度非常重要。如何在python中快速完成这个任务?
答案 0 :(得分:1)
使用numpy的meshgrid函数:
http://docs.scipy.org/doc/numpy/reference/generated/numpy.meshgrid.html
你需要将初始集重塑为1-D,但这并不是为了这个目的。
set1 = set1.flatten() // etc
然后调用meshgrid()。它将为您提供4个4-D阵列,每个阵列一个。然后添加:
a,b,c,d = np.meshgrid(set1, set2, set3, set4)
total = a+b+c+d
最后,计算总数组中0的数量:
count = len(total) - np.count_nonzero(sum)
答案 1 :(得分:1)
假设输入是一维数组的列表,如问题中提供的示例数据中所列,似乎您可以在对行输入列表进行行堆叠后使用broadcasting
,如此 -
import numpy as np
s1 = np.row_stack((set1))
s2 = np.row_stack((set2))
s3 = np.row_stack((set3))
s4 = np.row_stack((set4))
sums = s4[None,None,None,:,:] + s3[None,None,:,None,:] + s2[None,:,None,None,:] + s1[:,None,None,None,:]
count = (sums.reshape(-1,s1.shape[1])==0).all(1).sum()
示例运行 -
In [319]: set1 = [np.array([1, 0, 0]), np.array([-1, 0, 0]), np.array([0, 1, 0])]
...: set2 = [np.array([-1, 0, 0]), np.array([-1, 1, 0])]
...: set3 = [np.array([1, 0, 0]), np.array([-1, 0, 0]), np.array([0, 1, 0])]
...: set4 = [np.array([1, 0, 0]), np.array([-1, 0, 0]), np.array([0, 1, 0]), np.array([0, 1, 0])]
...:
In [320]: count = 0
...: for b1 in set1:
...: for b2 in set2:
...: for b3 in set3:
...: for b4 in set4:
...: if all(b1 + b2 + b3 + b4 == 0):
...: count = count + 1
...:
In [321]: count
Out[321]: 3
In [322]: s1 = np.row_stack((set1))
...: s2 = np.row_stack((set2))
...: s3 = np.row_stack((set3))
...: s4 = np.row_stack((set4))
...:
...: sums = s4[None,None,None,:,:] + s3[None,None,:,None,:] + s2[None,:,None,None,:] + s1[:,None,None,None,:]
...: count2 = (sums.reshape(-1,s1.shape[1])==0).all(1).sum()
...:
In [323]: count2
Out[323]: 3
答案 2 :(得分:1)
这个怎么样?
from numpy import array
def createset(min, max):
xr = lambda: xrange(min, max)
return [ array([x, y, z]) for x in xr() for y in xr() for z in xr() ]
set1 = createset(-3, 3)
set2 = createset(-2, 1)
set3 = createset(-4, 5)
set4 = createset(0, 2)
lookup = {}
for x in set1:
for y in set2:
key = tuple(x + y)
if key not in lookup:
lookup[key] = 0
lookup[key] += 1
count = 0
for x in set3:
for y in set4:
key = tuple(-1 * (x + y))
if key in lookup:
count += lookup[key]
print count
想法是生成前两组的所有总和。然后,生成最后两组的总和,并查看查找表中是否有一个键,使得它们的总和为0.
答案 3 :(得分:0)
您可以在itertools.product
函数中使用sum
和生成器表达式:
from itertools import combinations
sum(1 for i in produt(set1,set2,set3,set4) if sum(i)==0)
这比你的代码快,但仍然是O(n 4 ),你可以get the product with Numpy而不是itertools提高速度。