我遇到的问题是我想要计算满足以下条件的组合数量:
a < b < a+d < c < b+d
其中a, b, c
是列表的元素,而d
是固定的增量。
这是一个香草实现:
def count(l, d):
s = 0
for a in l:
for b in l:
for c in l:
if a < b < a + d < c < b + d:
s += 1
return s
这是一个测试:
def testCount():
l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
assert(32 == count(l, 4)) # Gone through everything by hand.
我怎样才能加快速度?我正在查看列表大小为2百万。
我正在处理[-pi,pi]范围内的浮动。例如,这会限制a < 0
。
我有一些实现,我构建了用于b
和c
的索引。但是,以下代码在某些情况下失败。 (即这是错误的)。
def count(l, d=pi):
low = lower(l, d)
high = upper(l, d)
s = 0
for indA in range(len(l)):
for indB in range(indA+1, low[indA]+1):
s += low[indB] + 1 - high[indA]
return s
def lower(l, d=pi):
'''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
Input must be sorted!
'''
ind = []
x = 0
length = len(l)
for elem in l:
while x < length and l[x] < elem + d:
x += 1
if l[x-1] < elem + d:
ind.append(x-1)
else:
assert(x == length)
ind.append(x)
return ind
def upper(l, d=pi):
''' Returns first index where l[i] > l + d'''
ind = []
x = 0
length = len(l)
for elem in l:
while x < length and l[x] <= elem + d:
x += 1
ind.append(x)
return ind
最初的问题来自众所周知的数学/综合竞赛。比赛要求您不要在网上发布解决方案。但它是从两周前开始的。
我可以用这个函数生成列表:
def points(n):
x = 1
y = 1
for _ in range(n):
x = (x * 1248) % 32323
y = (y * 8421) % 30103
yield atan2(x - 16161, y - 15051)
def C(n):
angles = points(n)
angles.sort()
return count(angles, pi)
答案 0 :(得分:2)
from bisect import bisect_left, bisect_right
from collections import Counter
def count(l, d):
# cdef long bleft, bright, cleft, cright, ccount, s
s = 0
# Find the unique elements and their counts
cc = Counter(l)
l = sorted(cc.keys())
# Generate a cumulative sum array
cumulative = [0] * (len(l) + 1)
for i, key in enumerate(l, start=1):
cumulative[i] = cumulative[i-1] + cc[key]
# Pregenerate all the left and right lookups
lefthand = [bisect_right(l, a + d) for a in l]
righthand = [bisect_left(l, a + d) for a in l]
aright = bisect_left(l, l[-1] - d)
for ai in range(len(l)):
bleft = ai + 1
# Search only the values of a that have a+d in range
if bleft > aright:
break
# This finds b such that a < b < a + d.
bright = righthand[ai]
for bi in range(bleft, bright):
# This finds the range for c such that a+d < c < b+d.
cleft = lefthand[ai]
cright = righthand[bi]
if cleft != cright:
# Find the count of c elements in the range cleft..cright.
ccount = cumulative[cright] - cumulative[cleft]
s += cc[l[ai]] * cc[l[bi]] * ccount
return s
def testCount():
l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
result = count(l, 4)
assert(32 == result)
testCount()
摆脱重复的,相同的值
仅迭代值
使用两个索引的累积计数来消除c
x + d
这不再是O(n^3)
,而是更像O(n ^ 2)`。
这显然还没有达到200万。以下是使用cython加速执行的较小浮点数据集(即很少或没有重复)的时间:
50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds
这是我的基准测试代码(不包括上面的操作代码):
from math import atan2, pi
def points(n):
x, y = 1, 1
for _ in range(n):
x = (x * 1248) % 32323
y = (y * 8421) % 30103
yield atan2(x - 16161, y - 15051)
def C(n):
angles = sorted(points(n))
return count(angles, pi)
def test_large():
from datetime import datetime
for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
s = datetime.now()
C(n)
elapsed = datetime.now() - s
print("{1}: {0} seconds".format(elapsed, n))
if __name__ == '__main__':
testCount()
test_large()
答案 1 :(得分:2)
您的问题有一种方法可以产生O(n log n)
算法。设X
为值集。现在让我们修复b
。让A_b
为值集{ x in X: b - d < x < b }
,C_b
为值{ x in X: b < x < b + d }
的集合。如果我们能够快速找到|{ (x,y) : A_b X C_b | y > x + d }|
,我们就解决了这个问题。
如果我们对X
进行排序,我们可以将A_b
和C_b
表示为已排序数组的指针,因为它们是连续的。如果我们以非递减顺序处理b
个候选者,我们可以使用sliding window algorithm维护这些集合。它是这样的:
X
。设X = { x_1, x_2, ..., x_n }
,x_1 <= x_2 <= ... <= x_n
。left = i = 1
并设置right
以便C_b = { x_{i + 1}, ..., x_right }
。设置count = 0
i
到1
迭代n
。在每次迭代中,我们都会找到(a,b,c)
的有效三元组b = x_i
的数量。为此,请尽可能多地增加left
和right
,以便A_b = { x_left, ..., x_{i-1} }
和C_b = { x_{i + 1}, ..., x_right }
仍然有效。在此过程中,您基本上可以添加和删除虚构集A_b
和C_b
中的元素。
如果您删除或向其中一个集添加元素,请检查来自(a, c)
的{{1}} c > a + d
,a
和A_b
c
对C_b
你添加或销毁(这可以通过另一组中的简单二进制搜索来实现)。相应地更新count
,以便不变的count = |{ (x,y) : A_b X C_b | y > x + d }|
仍然存在。count
的值。这是最终结果。复杂性为O(n log n)
。
如果要使用此算法解决Euler问题,则必须避免出现浮点问题。我建议使用仅使用整数算术的自定义比较函数(使用2D矢量几何)按角度对点进行排序。实现|a-b| < d
比较也可以仅使用整数运算来完成。此外,由于您正在使用模2*pi
,因此您可能需要引入每个角度a
的三个副本:a - 2*pi
,a
和a + 2*pi
。然后,您只需在b
范围内查找[0, 2*pi)
,并将结果除以3。
UPDATE OP在Python中实现了这个算法。显然它包含一些错误,但它表明了一般的想法:
def count(X, d):
X.sort()
count = 0
s = 0
length = len(X)
a_l = 0
a_r = 1
c_l = 0
c_r = 0
for b in X:
if X[a_r-1] < b:
# find boundaries of A s.t. b -d < a < b
while a_r < length and X[a_r] < b:
a_r += 1 # This adds an element to A_b.
ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
if c_l <= ind < c_r:
count += (ind - c_l)
while a_l < length and X[a_l] <= b - d:
a_l += 1 # This removes an element from A_b
ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
if c_l <= ind < c_r:
count -= (c_r - ind)
# Find boundaries of C s.t. b < c < b + d
while c_l < length and X[c_l] <= b:
c_l += 1 # this removes an element from C_b
ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
if a_l <= ind <= a_r:
count -= (ind - a_l)
while c_r < length and X[c_r] < b + d:
c_r += 1 # this adds an element to C_b
ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
if a_l <= ind <= a_r:
count += (ind - a_l)
s += count
return s
答案 2 :(得分:1)
由于l
已排序且a < b < c
必须为true,因此您可以使用itertools.combinations()
来减少循环次数:
sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
查看组合只会将此循环减少到816次迭代。
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32
a < b
测试是多余的。
答案 3 :(得分:1)
1)为了减少每个级别的迭代次数,您可以从列表中删除不通过每个级别上的条件的元素
2)将set
与collections.counter
一起使用,可以通过删除重复项来减少迭代次数:
from collections import Counter
def count(l, d):
n = Counter(l)
l = set(l)
s = 0
for a in l:
for b in (i for i in l if a < i < a+d):
for c in (i for i in l if a+d < i < b+d):
s += (n[a] * n[b] * n[c])
return s
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32
测试版本的迭代次数(a,b,c):
>>> count1(l, 4)
18 324 5832
我的版本:
>>> count2(l, 4)
9 16 7
答案 4 :(得分:0)
基本思路是:
因此,您可以无条件地增加s,性能大致为O(N),N为数组的大小。
import collections
def count(l, d):
s = 0
# at first we get rid of repeated items
counter = collections.Counter(l)
# sort the list
uniq = sorted(set(l))
n = len(uniq)
# kad is the index of the first element > a+d
kad = 0
# ka is the index of a
for ka in range(n):
a = uniq[ka]
while uniq[kad] <= a+d:
kad += 1
if kad == n:
return s
for kb in range( ka+1, kad ):
# b only runs in the range [a..a+d)
b = uniq[kb]
if b >= a+d:
break
for kc in range( kad, n ):
# c only rund from (a+d..b+d)
c = uniq[kc]
if c >= b+d:
break
print( a, b, c )
s += counter[a] * counter[b] * counter[c]
return s
编辑:对不起,我把提交搞砸了。固定的。