如何计算组合数?

时间:2014-02-08 15:43:31

标签: python algorithm combinations

我遇到的问题是我想要计算满足以下条件的组合数量:

 a < b < a+d < c < b+d

其中a, b, c是列表的元素,而d是固定的增量。

这是一个香草实现:

def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s

这是一个测试:

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4)) # Gone through everything by hand.

问题

我怎样才能加快速度?我正在查看列表大小为2百万。

补充资料

我正在处理[-pi,pi]范围内的浮动。例如,这会限制a < 0

到目前为止我所拥有的:

我有一些实现,我构建了用于bc的索引。但是,以下代码在某些情况下失败。 (即这是错误的)。

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
            for indB in range(indA+1, low[indA]+1):
                    s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for  elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind


def upper(l, d=pi):
    ''' Returns first index where l[i] > l + d'''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind

原始问题

最初的问题来自众所周知的数学/综合竞赛。比赛要求您不要在网上发布解决方案。但它是从两周前开始的。

我可以用这个函数生成列表:

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = points(n)
    angles.sort()
    return count(angles, pi)

5 个答案:

答案 0 :(得分:2)

from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)

    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()
  1. 摆脱重复的,相同的值

  2. 仅迭代值

  3. 所需的范围
  4. 使用两个索引的累积计数来消除c

  5. 上的循环
  6. x + d

  7. 上的缓存查找

    这不再是O(n^3),而是更像O(n ^ 2)`。

    这显然还没有达到200万。以下是使用cython加速执行的较小浮点数据集(即很少或没有重复)的时间:

    50: 0:00:00.157849 seconds
    100: 0:00:00.003752 seconds
    200: 0:00:00.022494 seconds
    400: 0:00:00.071192 seconds
    800: 0:00:00.253750 seconds
    1600: 0:00:00.951133 seconds
    3200: 0:00:03.508596 seconds
    6400: 0:00:10.869102 seconds
    12800: 0:00:55.986448 seconds
    

    这是我的基准测试代码(不包括上面的操作代码):

    from math import atan2, pi
    
    def points(n):
        x, y = 1, 1
        for _ in range(n):
            x = (x * 1248) % 32323
            y = (y * 8421) % 30103
            yield atan2(x - 16161, y - 15051)
    
    def C(n):
        angles = sorted(points(n))
        return count(angles, pi)
    
    def test_large():
        from datetime import datetime
        for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
            s = datetime.now()
            C(n)
            elapsed = datetime.now() - s
            print("{1}: {0} seconds".format(elapsed, n))
    
    if __name__ == '__main__':
        testCount()
        test_large()
    

答案 1 :(得分:2)

您的问题有一种方法可以产生O(n log n)算法。设X为值集。现在让我们修复b。让A_b为值集{ x in X: b - d < x < b }C_b为值{ x in X: b < x < b + d }的集合。如果我们能够快速找到|{ (x,y) : A_b X C_b | y > x + d }|,我们就解决了这个问题。

如果我们对X进行排序,我们可以将A_bC_b表示为已排序数组的指针,因为它们是连续的。如果我们以非递减顺序处理b个候选者,我们可以使用sliding window algorithm维护这些集合。它是这样的:

  1. 排序X。设X = { x_1, x_2, ..., x_n }x_1 <= x_2 <= ... <= x_n
  2. 设置left = i = 1并设置right以便C_b = { x_{i + 1}, ..., x_right }。设置count = 0
  3. i1迭代n。在每次迭代中,我们都会找到(a,b,c)的有效三元组b = x_i的数量。为此,请尽可能多地增加leftright,以便A_b = { x_left, ..., x_{i-1} }C_b = { x_{i + 1}, ..., x_right }仍然有效。在此过程中,您基本上可以添加和删除虚构集A_bC_b中的元素。 如果您删除或向其中一个集添加元素,请检查来自(a, c)的{​​{1}} c > a + daA_b cC_b你添加或销毁(这可以通过另一组中的简单二进制搜索来实现)。相应地更新count,以便不变的count = |{ (x,y) : A_b X C_b | y > x + d }|仍然存在。
  4. 在每次迭代中总结count的值。这是最终结果。
  5. 复杂性为O(n log n)

    如果要使用此算法解决Euler问题,则必须避免出现浮点问题。我建议使用仅使用整数算术的自定义比较函数(使用2D矢量几何)按角度对点进行排序。实现|a-b| < d比较也可以仅使用整数运算来完成。此外,由于您正在使用模2*pi,因此您可能需要引入每个角度a的三个副本:a - 2*piaa + 2*pi。然后,您只需在b范围内查找[0, 2*pi),并将结果除以3。

    UPDATE OP在Python中实现了这个算法。显然它包含一些错误,但它表明了一般的想法:

    def count(X, d):
        X.sort()
        count = 0
        s = 0
        length = len(X)
        a_l = 0
        a_r = 1
        c_l = 0
        c_r = 0
        for b in X:
            if X[a_r-1] < b:
                # find boundaries of A s.t. b -d < a < b
                while a_r < length and X[a_r] < b:
                    a_r += 1  # This adds an element to A_b. 
                    ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                    if c_l <= ind < c_r:
                        count += (ind - c_l)
                while a_l < length and X[a_l] <= b - d:
                    a_l += 1  # This removes an element from A_b
                    ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                    if c_l <= ind < c_r:
                        count -= (c_r - ind)
                # Find boundaries of C s.t. b < c < b + d
                while c_l < length and X[c_l] <= b:
                    c_l += 1  # this removes an element from C_b
                    ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
                    if a_l <= ind <= a_r:
                        count -= (ind - a_l)
                while c_r  < length and X[c_r] < b + d:
                    c_r += 1 # this adds an element to C_b
                    ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
                    if a_l <= ind <= a_r:
                        count += (ind - a_l)
                s += count
        return s
    

答案 2 :(得分:1)

由于l已排序且a < b < c必须为true,因此您可以使用itertools.combinations()来减少循环次数:

sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)

查看组合只会将此循环减少到816次迭代。

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32

a < b测试是多余的。

答案 3 :(得分:1)

1)为了减少每个级别的迭代次数,您可以从列表中删除不通过每个级别上的条件的元素
2)将setcollections.counter一起使用,可以通过删除重复项来减少迭代次数:

from collections import Counter
def count(l, d):
    n = Counter(l)
    l = set(l)
    s = 0
    for a in l:
        for b in (i for i in l if a < i < a+d):
            for c in (i for i in l if a+d < i < b+d):
                s += (n[a] * n[b] * n[c])
    return s

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32

测试版本的迭代次数(a,b,c):

>>> count1(l, 4)
18 324 5832

我的版本:

>>> count2(l, 4)
9 16 7

答案 4 :(得分:0)

基本思路是:

  1. 摆脱重复的,相同的价值
  2. 让每个值仅迭代它必须迭代的范围。
  3. 因此,您可以无条件地增加s,性能大致为O(N),N为数组的大小。

    import collections
    
    def count(l, d):
        s = 0
        # at first we get rid of repeated items
        counter = collections.Counter(l)
        # sort the list
        uniq = sorted(set(l))
        n = len(uniq)
        # kad is the index of the first element > a+d
        kad = 0 
        # ka is the index of a
        for ka in range(n):
            a = uniq[ka]
            while uniq[kad] <= a+d:
                kad += 1
                if kad == n:
                    return s
    
            for kb in range( ka+1, kad ):
                # b only runs in the range [a..a+d)
                b = uniq[kb]
                if b  >= a+d:
                    break
                for kc in range( kad, n ):
                    # c only rund from (a+d..b+d)
                    c = uniq[kc]
                    if c >= b+d:
                        break
                    print( a, b, c )
                    s += counter[a] * counter[b] * counter[c]
        return s
    
    编辑:对不起,我把提交搞砸了。固定的。