Python:如何有效地计算条件和?

时间:2013-12-27 04:44:33

标签: python performance loops iterator

编辑:我正在研究性能敏感的情况,需要使用用户定义的检查点计算数据的总和或最大值。请参阅演示代码:

from itertools import izip
timestamp=[1,2,3,4,...]#len(timestamp)=N
checkpoints=[1,3,5,7,..]#user defined
data=([1,1,1,1,...],
      [2,2,2,2,...],
      ...)#len(data)=M,len(data[any])=N
processtype=('sum','max','min','snapshot',...)#len(processtype)=M

def processdata(timestamp, checkpoints, data, processtype):
    checkiter=iter(checkpoints)
    checher=checkiter.next()
    tmp=[0 if t=='sum' else None for t in processtype]
    for x, d in izip(timestamp,izip(*data)):
        tmp =[tmp[i]+d[i] if t=='sum' else
              d[i] if (t=='snapshot'
                   or (tmp[i] is None)
                   or (t=='max' and tmp[i]<d[i])
                   or (t=='min' and tmp[i]>d[i])) else
              tmp[i] for (i,t) in enumerate(processtype)]
        if x>checher:
            yield (checher,tmp)
            checher=checkiter.next()
            tmp=[0 if t=='sum' else None for t in processtype]

基准测试的原始演示:

def speratedsum(iter, condition):
    tmp=0
    for x in iter:
        if condition(x):
            yield tmp
            tmp=0
        else:
            tmp+=x

编辑:感谢@ M4rtini和@Chronial我在以下测试代码上运行了banchmark:

from timeit import timeit

it=xrange(100001)
condition=lambda x: x % 100 == 0

def speratedsum(it, condition):
    tmp=0
    for x in it:
        if condition(x):
            yield tmp+x
            tmp=0
        else:
            tmp+=x

def test1():
    return list(speratedsum(it,condition))

def red_func2(acc, x):
    if condition(x):
        acc[0].append(acc[1]+x)
        return (acc[0], 0)
    else:
        return (acc[0], acc[1] + x)

def test2():
    return reduce(red_func2, it,([], 0))[0]

def red_func3(l, x):
    if condition(x):
        l[-1] += x
        l.append(0)
    else:
        l[-1] += x
    return l

def test3():
    return reduce(red_func3, it, [0])[:-1]

import itertools
def test4():
    groups = itertools.groupby(it, lambda x: (x-1) / 100)
    return map(lambda g: sum(g[1]), groups)

import numpy as np
import numba
@numba.jit(numba.int_[:](numba.int_[:],numba.int_[:]),
           locals=dict(si=numba.int_,length=numba.int_))
def jitfun(arr,con):    
    length=arr.shape[0]
    out=np.zeros(con.shape[0],int)
    si=0
    for i in range(length):        
        out[si]+=arr[i]
        if(arr[i]>=con[si]):
            si+=1
    return out

conditionlist=[x for x in it if condition(x)]
a=np.array(it, int)
c=np.array(conditionlist,int)
def test5():
    return list(jitfun(a,c))
test5() #warm up for JIT

time1=timeit(test1,number=100)
time2=timeit(test2,number=100)
time3=timeit(test3,number=100)
time4=timeit(test4,number=100)
time5=timeit(test5,number=100)

print "test1:",test1()==test1(),time1/time1
print "test2:",test1()==test2(),time1/time2
print "test3:",test1()==test3(),time1/time3
print "test4:",test1()==test4(),time1/time4
print "test5:",test1()==test5(),time1/time5

输出:

test1: True 1.0
test2: True 0.369117307201
test3: True 0.496470798051
test4: True 0.833137283359
test5: True 34.1052257366

您对我应该寻求的地方有什么建议吗?谢谢!

编辑:我设法使用带有回调的numba解决方案来替换yield,这是在这里真正起作用的最省力的解决方案。所以接受了@ M4rtini的回答。但是要小心numba的局限性。经过我2天的尝试,numba可以增强numpy数组索引迭代性能,但仅此而已。

5 个答案:

答案 0 :(得分:2)

您似乎非常确定这是程序的缓慢部分,但标准建议是编写可读性,然后在必要时根据需要修改性能 - 在分析之后。

这是我前段时间写的关于加快Python速度的页面: http://stromberg.dnsalias.org/~dstromberg/speeding-python/

如果您没有使用任何第三方C扩展模块,Pypy可能是您的最佳选择。如果您使用的是第三方C扩展模块,请查看numba和/或Cython。

答案 1 :(得分:1)

为了完成它,这里是一个使用reduce的实现(虽然应该有可怕的性能):)

res = reduce(lambda acc, x:
            (acc[0] + [acc[1]], 0) if condition(x) else
            (acc[0], acc[1] + x),
            iter,
            ([], 0))[0]

这应该快得多,但我不是那么“干净”,因为它会改变累积列表。

def red_func(l, x):
    if condition(x):
        l.append(0)
    else:
        l[-1] = l[-1] + x
    return l
res = reduce(red_func, iter, [0])[:-1]

答案 2 :(得分:1)

您的原始版本可以通过groupby

解决
for key, group in itertools.groupby(iter, condition):
    if not key:
        yield sum(group)

这假定条件返回TrueFalse或其他一组两种可能性。如果它可以返回0,1,2,3或类似的东西,你会想要首先将回报转换为bool

for key, group in itertools.groupby(iter, lambda x: bool(condition(x))):
    #...

groupby会将具有相同键的项目按顺序分组到一个组中。在这里,我们将条件下False的连续项集合在一起,然后产生组的总和。

这确实错过了在原始版本为0的情况下连续两个项目为True的情况。

答案 3 :(得分:1)

以下是使用itertools.groupbyitertools.imap的解决方案:

iter = xrange(0, 10000)
groups = itertools.groupby(iter, lambda x: x / 100)
sums = itertools.imap(lambda g: sum(list(g[1])[1:]), groups)

请注意,它会产生略微不同的结果;结果列表中不会有前导零,并且它会产生一个额外的组,因为您没有产生最后一个组。

答案 4 :(得分:1)

import numba
@numba.autojit
def speratedsum2():
    s = 0
    tmp=0
    for x in xrange(10000):
        if x % 100 == 0:
            s += tmp
            tmp=0
        else:
            tmp+=x
    return s


In [140]: %timeit sum([x for x in speratedsum1()])
1000 loops, best of 3: 625 µs per loop

In [142]: %timeit speratedsum2()
10000 loops, best of 3: 113 µs per loop