为什么"地图" ThreeSum的版本这么慢?

时间:2014-09-07 16:50:05

标签: python performance

我预计ThreeSum的这个Python实现会很慢:

def count(a):
       """ThreeSum: Given N distinct integers, how many triples sum to exactly zero?"""
       N = len(a)
       cnt = 0
       for i in range(N):
         for j in range(i+1, N):
           for k in range(j+1, N):
             if sum([a[i], a[j], a[k]]) == 0:
               cnt += 1
       return cnt 

但我感到震惊的是,这个版本看起来也很慢:

def count_python(a):
  """ThreeSum using itertools"""
  return sum(map(lambda X: sum(X)==0, itertools.combinations(a, r=3))) 

有人可以推荐更快的Python实现吗?这两种实现看起来都很慢...... 感谢

...

答案摘要: 以下是在我的机器上运行O(N ^ 3)(用于教育目的,未在现实生活中使用)版本的此线程中提供的所有各种版本的运行:

56秒RUNNING count_slow ...
28秒RUNNING count_itertools,由Ashwini Chaudhary撰写...
14秒RUNNING count_fixed,由roippi写的...
11秒运行count_itertools(更快),由Veedrak撰写...
08秒RUNNING count_enumerate,由roippi写的......

*注意:需要修改Veedrak的解决方案以获得正确的计数输出:
sum(如果x + y == - z,则在itertools.combinations(a,r = 3)中为x,y,z为1)

3 个答案:

答案 0 :(得分:3)

在算法上,你的函数的两个版本都是O(n ** 3) - 所以渐近地两者都不是优越的。你会发现itertools版本实际上有点快,因为它花费更多的时间在C中循环而不是在python字节码中循环。你可以通过完全删除map来降低几个百分点(特别是如果你正在运行py2),但它仍然会慢下来#34;与在JVM中运行它所获得的任何时间相比。

请注意,除了cPython之外还有很多python实现 - 对于循环代码,pypy往往比cPython更快 。因此,我不会将python-as-a-language编写为缓慢,必然,但我肯定会说python的参考实现并不以其炽热的循环速度而闻名。如果你关心的话,可以给其他蟒蛇口味。

特定于您的算法,优化将允许您将其降至O(n ** 2)。构建一组整数s,并构建所有对(a,b)。你知道你可以"归零" (a+b)当且仅当-(a+b) in (s - {a,b})

感谢@Veedrak:遗憾的是构建s - {a,b}是一个缓慢的O(len(s))操作本身 - 所以只需检查-(a+b)是否等于a或{{1 }}。如果是,您知道没有第三个b可以履行c,因为您输入中的所有数字都是不同的。

a+b+c == 0

注意最后三分;这是因为每个成功的组合都是三重计数。可以避免这种情况,但它实际上并没有加快速度,而且(imo)只会使代码复杂化。

好奇的一些时间:

def count_python_faster(a):
      s = frozenset(a)
      return sum(1 for x,y in itertools.combinations(a,2)
             if -(x+y) not in (x,y) and -(x+y) in s) // 3

答案 1 :(得分:3)

提供第二个答案。从各种评论中,看起来你主要关注为什么这个特定的O(n ** 3)算法在从java移植时很慢。让我们潜入。

def count(a):
       """ThreeSum: Given N distinct integers, how many triples sum to exactly zero?"""
       N = len(a)
       cnt = 0
       for i in range(N):
         for j in range(i+1, N):
           for k in range(j+1, N):
             if sum([a[i], a[j], a[k]]) == 0:
               cnt += 1
       return cnt

立即弹出的一个主要问题是你正在做一些你的java代码几乎肯定没做的事情:实现一个3元素列表只是为了将三个数字加在一起!

if sum([a[i], a[j], a[k]]) == 0:

呸!只需将其写为

即可
if a[i] + a[j] + a[k] == 0:

一些基准测试显示,您正在通过这样做添加50%以上的开销 。让人惊讶。


这里的另一个问题是你正在使用索引,你应该使用迭代。在python中尽量避免编写这样的代码:

for i in range(len(some_list)):
    do_something(some_list[i])

而只是写:

for x in some_list:
    do_something(x)

如果你明确需要你所使用的索引(正如你在代码中所做的那样),请使用enumerate

for i,x in enumerate(some_list):
    #etc

一般来说,这是一种风格的东西(虽然它比那更深入,有鸭子类型和迭代器协议) - 但它也是一种表现的东西。为了查找a[i]的值,该调用转换为a.__getitem__(i),然后python必须动态解析__getitem__方法查找,调用它并返回值。每次。这不是一个疯狂的开销 - 至少在内置类型上 - 但如果你在循环中做了很多,它会增加。另一方面,将a视为可迭代,可以避免大量的开销。

所以记住这个改变,你可以再次重写你的功能:

def count_enumerate(a):
    cnt = 0
    for i, x in enumerate(a):
        for j, y in enumerate(a[i+1:], i+1):
            for z in a[j+1:]:
                if x + y + z == 0:
                    cnt += 1
    return cnt

让我们看看一些时间:

%timeit count(range(-100,100))
1 loops, best of 3: 394 ms per loop

%timeit count_fixed(range(-100,100)) #just fixing your sum() line
10 loops, best of 3: 158 ms per loop

%timeit count_enumerate(range(-100,100))
10 loops, best of 3: 88.9 ms per loop

这和它的速度一样快。通过将所有内容包含在理解中而不是执行cnt += 1,您可以减少一个百分比左右,但这非常小。

我玩过几个itertools实现,但实际上我不能让它们比这个显式循环版本更快。如果您考虑它,这是有道理的 - 对于每次迭代,itertools.combinations版本必须重新绑定所有三个变量所引用的内容,而显式循环则“欺骗”并重新绑定变量在外圈很少见。

现实检查时间:在完成所有操作之后,您仍然可以期望cPython运行此算法比现代JVM慢一个数量级。 python中内置了太多的抽象,阻碍了快速循环。如果您关心速度(并且无法修复算法 - 请参阅我的其他答案),请使用类似numpy之类的内容将所有时间用在C中循环,或者使用不同的python实现。


postscript:pypy

为了好玩,我在cPython和pypy上的1000个元素列表上运行count_fixed

<强> CPython的:

In [81]: timeit.timeit('count_fixed(range(-500,500))', setup='from __main__ import count_fixed', number = 1)
Out[81]: 19.230753898620605

<强> pypy:

>>>> timeit.timeit('count_fixed(range(-500,500))', setup='from __main__ import count_fixed', number = 1)
0.6961538791656494

迅速!

我可能会在稍后添加一些java测试来比较: - )

答案 2 :(得分:2)

您尚未说明您使用的是哪个版本的Python。

在Python 3.x中,生成器表达式比您列出的两个实现中的任何一个快约10%。对a使用[-100,100]范围内的100个数字的随机数组:

count(a)           -> 8.94 ms  # as per your implementation
count_python(a)    -> 8.75 ms  # as per your implementation

def count_generator(a):
    return sum((sum(x) == 0 for x in itertools.combinations(a,r=3)))

count_generator(a) -> 7.63 ms

但除此之外,它是主导执行时间的组合的剪切量 - O(N ^ 3)。

我应该添加上面显示的时间,每个循环包含10个调用,平均超过10个循环。是的,我的笔记本电脑也很慢:)