我希望列表中的元素数量满足另一个列表指定的某些条件。我的方式是使用sum
和any
。简单的测试代码是:
>>> x1 = list(xrange(300))
>>> x2 = [random.randrange(20, 50) for i in xrange(30)]
>>> def test():
ns = []
for i in xrange(10000):
ns.append(sum(1 for j in x2 if any(abs(k-j)<=10 for k in x1)))
return ns
使用Profiler表示sum
和any
导致了最多的时间,有什么方法可以改善这一点吗?
>>> cProfile.run('ns = test()')
8120003 function calls in 0.699 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.003 0.003 1.552 1.552 <pyshell#678>:2(test)
310000 0.139 0.000 1.532 0.000 <pyshell#678>:5(<genexpr>)
1 0.000 0.000 1.552 1.552 <string>:1(<module>)
7490000 0.196 0.000 0.196 0.000 {abs}
300000 0.345 0.000 1.377 0.000 {any}
10000 0.001 0.000 0.001 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
10000 0.016 0.000 1.548 0.000 {sum}
函数test
仅包含10000次迭代。通常,我会进行数万次迭代,并且使用cProfile.run
表明此块会导致大部分执行时间。
=============================================== ====================
修改
根据@DavisHerring's回答,使用二分搜索。
from _bisect import *
>>> x1 = list(xrange(300))
>>> x2 = [random.randrange(20, 50) for i in xrange(30)]
>>> def testx():
ns = []
x2k = sorted(x2)
x1k = sorted(x1)
for i in xrange(10000):
bx = [bisect_left(x1k, xk) for xk in x2k]
rn = sum(1 if k==0 and x1k[k]-xk<=10
else 1 if k==len(x1k) and xk-x1k[k-1]<=10
else xk-x1k[k-1]<=10 or x1k[k]-xk<=10
for k, xk in zip(bx, x2k))
ns.append(rn)
return ns
根据cProfile.run
,达到0.196 seconds
,速度提高3倍+。
答案 0 :(得分:2)
谓词的本质是至关重要的;因为它是一条线的距离,您可以为您的数据提供相应的结构以加快搜索速度。有几种变化:
对列表x1
进行排序:然后您可以使用二进制搜索来查找最近的值并检查它们是否足够接近。
如果列表x2
更长,并且其大部分元素不在范围内,则可以通过对其进行排序并搜索每个可接受间隔的开始和结束来使其快一些。
如果您对两个列表进行排序,您可以一起单步执行它们并在线性时间内完成。这是渐近等效的,除非有其他理由对它们进行排序,当然。
答案 1 :(得分:0)
使用interval tree数据结构。适合您需求的非常简单的实现可以如下:
class SimpleIntervalTree:
def __init__(self, points, radius):
intervals = []
l, r = None, None
for p in sorted(points):
if r is None or r < p - radius:
if r is not None:
intervals.append((l, r))
l = p - radius
r = p + radius
if r is not None:
intervals.append((l, r))
self._tree = self._to_tree(intervals)
def _to_tree(self, intervals):
if len(intervals) == 0:
return None
i = len(intervals) // 2
return {
'left': self._to_tree(intervals[0:i]),
'value': intervals[i],
'right': self._to_tree(intervals[i + 1:])
}
def __contains__(self, item):
t = self._tree
while t is not None:
l, r = t['value']
if item < l:
t = t['left']
elif item > r:
t = t['right']
else:
return True
return False
然后你的代码看起来像这样:
x1 = list(range(300))
x2 = [random.randrange(20, 50) for i in range(30)]
it = SimpleIntervalTree(x1, 10)
def test():
ns = []
for i in range(10000):
ns.append(sum(1 for j in x2 if j in it))
return ns
在__init__
中,点列表首先转换为连续间隔列表。接下来,将间隔放在平衡的binary search tree中。在该树中,每个节点包含一个间隔,节点的每个左子树包含较低的间隔,并且节点的每个右子树包含较高的间隔。这样,每当我们想要测试一个点是否在任何段(__contains__
)中时,我们就从根开始执行二进制搜索。