Python OOP不相交集性能

时间:2017-04-26 16:40:23

标签: python algorithm performance kruskals-algorithm disjoint-sets

我构建了一个不相交的数据结构,用于Kruskal的MST算法。我需要加载,然后将图表与200k互连节点联合起来,我认为我的数据结构实现是一个瓶颈。

您对如何提高性能有什么建议吗?我认为我的find方法可能存在问题。

class partition(object):
    def __init__(self, element=None):
        self.size = 0
        if element == None:
            self.contents = set()
            self.representative = None
        else:
            self.contents = {element}
            self.representative = element
            self.size = 1

    def find(self, element):
        return element in self.contents

    def add(self, partition):
        self.contents = self.contents.union(partition)
        self.size = len(self.contents)

    def show(self):
        return self.contents

    def __repr__(self):
        return str(self.contents)

class disjoint_set(object):
    def __init__(self):
        self.partitions_count = 0
        self.forest = {}

    def make_set(self, element):
        if self.find(element) == False:
            new_partition = partition(element)
            self.forest[new_partition.representative] = new_partition
            self.partitions_count += 1

    def union(self, x, y):
        if x != y:
            if self.forest[x].size < self.forest[y].size:
                self.forest[y].add(self.forest[x].show())
                self.delete(x)
            else:
                self.forest[x].add(self.forest[y].show())
                self.delete(y)

    def find(self, element):
        for partition in self.forest.keys():
            if self.forest[partition].find(element):
                return self.forest[partition].representative
        return False

    def delete(self, partition):
        del self.forest[partition]
        self.partitions_count -= 1

if __name__ == '__main__':
    t = disjoint_set()
    t.make_set(1)
    t.make_set(2)
    t.make_set(3)
    print("Create 3 singleton partitions:")
    print(t.partitions_count)
    print(t.forest)
    print("Union two into a single partition:")
    t.union(1,2)
    print(t.forest)
    print(t.partitions_count)

编辑:

在阅读评论并进行额外研究后,我意识到我的原始算法设计得很差。我从零开始,把它放在一起。我将所有分区放入一个哈希表中,并在find()中使用路径压缩。这看起来怎么样,我应该解决哪些明显的问题?

class disjoint_set(object):
def __init__(self):
    self.partitions_count = 0
    self.size = {}
    self.parent = {}

def make_set(self, element):
    if self.find(element) == False:
        self.parent[element] = element
        self.size[element] = 1
        self.partitions_count += 1

def union(self, x, y):
    xParent = self.find(x)
    yParent = self.find(y)
    if xParent != yParent:
        if self.size[xParent] < self.size[yParent]:
            self.parent[xParent] = yParent
            self.size[yParent] += self.size[xParent]
            self.partitions_count -= 1
        else:
            self.parent[yParent] = xParent
            self.size[xParent] += self.size[yParent]
            self.partitions_count -= 1

def find(self, element):
    if element in self.parent:
        if element == self.parent[element]:
            return element
        root = self.parent[element]
        while self.parent[root] != root:
            root = self.find(self.parent[root])
        self.parent[element] = root
        return root
    return False

if __name__ == '__main__':
    t = disjoint_set()
    t.make_set(1)
    t.make_set(2)
    t.make_set(3)
    t.make_set(4)
    t.make_set(5)
    print("Create 5 singleton partitions")
    print(t.partitions_count)
    print("Union two singletons into a single partition")
    t.union(1,2)
    print("Union three singletones into a single partition")
    t.union(3,4)
    t.union(5,4)
    print("Union a single partition")
    t.union(2,4)
    print("Parent List: %s" % t.parent)
    print("Partition Count: %s" % t.partitions_count)
    print("Parent of element 2: %s" % t.find(2))
    print("Parent List: %s" % t.parent)

1 个答案:

答案 0 :(得分:0)

我猜你的find实现并没有运行,它本应该是。

以下更改可能有所帮助。

class disjoint_set(object):
    def __init__(self):
        self.partitions_count = 0
        self.forest = {}
        self.parent = {}

    def make_set(self, element):
        if not self.find(element):
            new_partition = partition(element)
            self.parent[element] = element
            self.forest[new_partition.representative] = new_partition
            self.partitions_count += 1

def union(self, x, y):
    if x != y:
        if self.forest[x].size < self.forest[y].size:
            self.forest[y].add(self.forest[x].show())
            #Update parent details 
            self.parent[self.forest[x].representative] = self.forest[y].representative
            self.delete(x)
        else:
            self.forest[x].add(self.forest[y].show())
            #Update parent details 
            self.parent[self.forest[y].representative] = self.forest[x].representative
            self.delete(y)

def find(self, element):
    if self.parent[element] == element:
        return element
    else:
        return find(element)

可以使用路径压缩优化代码,使disjoint_set.find在O(1)中运行。我猜O(log n)仍然适合大数字。

另一个瓶颈可能是你的工会职能。特别是添加功能实现。

def add(self, partition):
    self.contents = self.contents.union(partition)

尝试使用set的更新方法(这是一个inplace union)。我认为这会给巨大的节点带来大量的内存开销。像

这样的东西
self.contents.update(partition)

关于set union和update functions here的讨论非常好。

希望它有所帮助!