Python中的段树实现

时间:2016-12-11 11:18:45

标签: python algorithm segment-tree

我正在使用this解决segment tree问题,但我遇到时间限制错误。 以下是我的范围最小查询的原始代码,在我的代码中将min更改为max可以解决上述问题。我不知道如何提高代码的性能。你能帮我解决一下性能问题吗?

t = [None] * 2 * 7      # n is length of list


def build(a, v, start, end):
    '''
    A recursive function that constructs Segment Tree for list a.
    v is the starting node
    start and end are the index of array
    '''

    n = len(a)
    if start == end:
        t[v] = a[start]
    else:
        mid = (start + end) / 2
        build(a, v * 2, start, mid)     # v*2 is left child of parent v
        # v*2+1 is the right child of parent v
        build(a, v * 2 + 1, mid + 1, end)
        t[v] = min(t[2 * v], t[2 * v + 1])
    return t

print build([18, 17, 13, 19, 15, 11, 20], 1, 0, 6)

inf = 10**9 + 7


def range_minimum_query(node, segx, segy, qx, qy):
    '''
    returns the minimum number in range(qx,qy)
    segx and segy represent the segment index

    '''
    if qx > segy or qy < segx:      # query out of range
        return inf
    elif segx >= qx and segy <= qy:  # query range inside segment range
        return t[node]
    else:
        return min(range_minimum_query(node * 2, segx, (segx + segy) / 2, qx, qy), range_minimum_query(node * 2 + 1, ((segx + segy) / 2) + 1, segy, qx, qy))

print range_minimum_query(1, 1, 7, 1, 3)

# returns 13

这可以迭代实现吗?

2 个答案:

答案 0 :(得分:11)

语言选择

首先,如果你使用python,你可能永远不会通过评分者。如果你在这里查看所有过去解决方案的状态http://www.spoj.com/status/GSS1/start=0,你会发现每个几乎所有被接受的解决方案都是用C ++编写的。我认为您无法选择使用C ++。注意时间限制是0.115s-0.230s。这是一个“仅适用于C / C ++”的时间限制。对于将接受其他语言解决方案的问题,时间限制将是一个“圆”数,如1秒。在这种类型的环境中,Python比C ++慢约2-4倍。

细分树实施问题......?

其次,我不确定您的代码是否实际构建了一个分段树。具体来说,我不明白为什么这条线存在:

t[v]=min(t[2*v],t[2*v+1]) 

我很确定段树中的节点存储其子节点的总和,因此如果您的实现接近正确,我认为它应该改为

t[v] = t[2*v] + t[2*v+1]

如果您的代码是“正确的”,那么如果您甚至不存储间隔总和,我会质疑如何在[x_i, y_i]范围内找到最大间隔总和。

迭代段树

第三,可以迭代地实现分段树。这是C ++中的教程:http://codeforces.com/blog/entry/18051

段树不应该足够快......

最后,我不明白分段树如何帮助您解决此问题。通过细分树,您可以查询log(n)中范围的总和。此问题要求任何范围的最大可能总和。我没有听说过允许“范围最小查询”或“范围最大查询”的分段树。

对于1个查询,一个天真的解决方案将是O(n ^ 3)(尝试所有n ^ 2个可能的起点和终点并计算O(n)运算中的和)。并且,如果使用分段树,则可以在O(log(n))中获得总和而不是O(n)。这只能加速到O(n ^ 2 log(n)),这对N = 50000不起作用。

替代算法

我认为你应该看一下这个,它在每个查询的O(n)中运行:http://www.geeksforgeeks.org/largest-sum-contiguous-subarray/。用C / C ++编写,并像一位评论者建议的那样高效地使用IO。

答案 1 :(得分:2)

您可以尝试使用生成器,因为您可以绕过很多限制。但是,您没有提供明确显示您的性能问题的数据集 - 您能提供有问题的数据集吗?

您可以尝试:

t=[None]*2*7
inf=10**9+7

def build_generator(a, v, start, end):
    n = len(a)

    if start == end:
        t[v] = a[start]
    else:
        mid = (start + end) / 2
        next(build_generator(a, v * 2, start, mid))
        next(build_generator(a, v * 2 + 1, mid + 1, end))
        t[v] = min(t[2 * v], t[2 * v + 1])
    yield t



def range_minimum_query_generator(node,segx,segy,qx,qy):
    if qx > segy or qy < segx:
        yield inf
    elif segx >= qx and segy <= qy:
        yield t[node]
    else:
        min_ = min(
            next(range_minimum_query_generator(node*2,segx,(segx+segy)/2,qx,qy)),
            next(range_minimum_query_generator(node*2+1,((segx+segy)/2)+1,segy,qx,qy))
        )
        yield min_

next(build_generator([18,17,13,19,15,11,20],1,0,6))
value = next(range_minimum_query_generator(1, 1, 7, 1, 3))
print(value)

修改

事实上,这可能无法解决您的问题。还有另一种方法可以解决任何递归限制(如D. Beazley在其关于生成器的教程中所描述的那样 - https://www.youtube.com/watch?v=D1twn9kLmYg&t=9588s在时间码2h00附近)