使用itertools.combinations的最快方法

时间:2013-12-29 19:40:46

标签: python numpy

我需要加快以下功能:

import numpy as np
import itertools
import timeit

def combcol(myarr):
    ndims = myarr.shape[0]
    solutions = []
    for idx1, idx2, idx3, idx4, idx5, idx6 in itertools.combinations(np.arange(ndims), 6):
        c1, c2, c3, c4, c5, c6 = myarr[idx1,1], myarr[idx2,2], myarr[idx3,1], myarr[idx4,2], myarr[idx5,1], myarr[idx6,2]
        if c1-c2>0 and c2-c3<0 and c3-c4>0 and c4-c5<0 and c5-c6>0 :
            solutions.append(((idx1, idx2, idx3, idx4, idx5, idx6),(c1, c2, c3, c4, c5, c6)))
    return solutions

X = np.random.random((20, 10))  
Y = np.random.random((40, 10))  


if __name__=='__main__':
    from timeit import Timer
    t = Timer(lambda : combcol(X))
    t1 = Timer(lambda : combcol(Y))
    print('t : ',t.timeit(number=1),'t1 : ',t1.timeit(number=1))

结果:

t :  0.6165180211451455 t1 :  64.49216925614847 

算法对于我的标准使用来说太慢了(myarr.shape [0] = 500)。 是否有NumPy方法来减少此功能的执行时间(不浪费太多内存)? 是否可以在Cython中实现该问题?

我尝试过使用cProfile查看哪些部分很慢。 这里的大部分时间都花在调用combcol()上。

import profile
........
........
profile.run('print(len(combcol(Y))); print')

144547
         144559 function calls in 39.672 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   144547    0.641    0.000    0.641    0.000 :0(append)
        1    0.000    0.000    0.000    0.000 :0(arange)
        2    0.000    0.000    0.000    0.000 :0(charmap_encode)
        1    0.000    0.000   39.672   39.672 :0(exec)
        1    0.000    0.000    0.000    0.000 :0(len)
        1    0.000    0.000    0.000    0.000 :0(print)
        1    0.000    0.000    0.000    0.000 :0(setprofile)
        1    0.094    0.094   39.672   39.672 <string>:1(<module>)
        2    0.000    0.000    0.000    0.000 cp850.py:18(encode)
        1   38.938   38.938   39.578   39.578 essaiNumpy4.py:13(combcol)
        1    0.000    0.000   39.672   39.672 profile:0(print(len(combcol(Y))); print)
        0    0.000             0.000          profile:0(profiler)

最后我修改了这样的代码:

def combcol2(myarr):
    ndims = myarr.shape[0]
    myarr1 = myarr[:,1].tolist()
    myarr2 = myarr[:,2].tolist()
    solutions = []
    for idx1, idx2, idx3, idx4, idx5, idx6 in itertools.combinations(range(ndims), 6):
        if myarr1[idx1] > myarr2[idx2] < myarr1[idx3] > myarr2[idx4] < myarr1[idx5] > myarr2[idx6]:
            solutions.append(((idx1, idx2, idx3, idx4, idx5, idx6),(myarr1[idx1], myarr2[idx2], myarr1[idx3], myarr2[idx4], myarr1[idx5], myarr2[idx6])))
    return solutions

X = np.random.random((40, 10))

if __name__=='__main__':
    from timeit import Timer
    t = Timer(lambda : combcol2(X))
    print('t : ',t.timeit(number=1))

结果:

t :  4.341582240200919

3 个答案:

答案 0 :(得分:2)

我不知道你到底想要得到什么,但是根据你的代码,你的解决方案的约束,idx1 ... idx6的形式

X[idx1, 1] > X[idx2, 2]
X[idx3, 1] > X[idx2, 2]

X[idx3, 1] > X[idx4, 2]
X[idx5, 1] > X[idx4, 2]

X[idx5, 1] > X[idx6, 2]

你可以获得(几乎立即为形状&lt; = 500)这个形式的所有可加上对(idx1, idx2), (idx3, idx2), (idx3, idx4), (idx5, idx4), idx(5, idx6)

idx56 = lst = zip(*np.where((X[:,1].reshape(-1,1)>X[:,2].reshape(1,-1))))

然后你可以用:

生成所有可能的索引
import operator 
slst = sorted(lst, key=operator.itemgetter(1))
grps = [[x, list(g)] for x, g in itertools.groupby(slst, key=operator.itemgetter(1))]
idx345 = idx123 = [[x1, x2, x3] for _, g in grps for ((x1, x2), (x3, _)) in itertools.combinations(g, 2)]

在我的测试机器上计算这些列表花了几秒钟(idx345的长度超过200亿个项目)。你只需要在x3 = x3,x5 = x5上加入这些列表(我认为解决方案的大小会随着这个连接而显着增长,所以无法猜测它是否有用)。

答案 1 :(得分:2)

alko已经为您的计划列出了有用的改革,Tim Peters指出,500-choose-6超过21万亿(即21057686727000)。这个答案将指出原始程序的简单加速。 (我认为与alko的方法相比,这是一个小的加速,但以下值得注意未来的python编程。)

您的选择陈述是
if c1-c2>0 and c2-c3<0 and c3-c4>0 and c4-c5<0 and c5-c6>0 :
这相当于
if c1>c2 and c2<c3 and c3>c4 and c4<c5 and c5>c6 :
但在python解释器中,前者比后者长67%。例如,这两个案例的一些示例输出,在我的Intel i3-2120机器上(显然比你的机器快几倍)运行Python 2.7.5+:
('t : ', 0.12977099418640137, 't1 : ', 14.45378589630127)
('t : ', 0.0887291431427002, 't1 : ', 8.54729700088501)
4次此类运行的平均值比率为14.529 / 8.709 = 1.668。

答案 2 :(得分:1)

您可以使用递归迭代组合,而不是使用itetools.combinations()。这样做的好处是您可以测试,例如c1>c2,如果False,您可以跳过所有相关组合。