Dijkstra的SPF算法中两个顶点(节点)实例之间的TypeError

时间:2019-04-15 15:18:23

标签: python algorithm shortest-path dijkstra heapq

我目前正在研究解决火车时刻表优化问题,这是我学习的一部分。在这个问题中,必须最大化效用函数,从而增加访问的(关键)车站的数量,并减少列车的使用量以及列车运行的总分钟数。

问题包括工作站(节点)和连接(边缘)。这两个数据都首先从两个CSV文件加载。然后,为每个站点(包含名称以及是否关键)和每个连接(在连接中包含站点以及彼此花费的时间)实例化类。这些站点和连接都存储在词典中。

第一步,我和我的队友们决定首先要实现Dijkstra的寻路算法版本,以便找到两个站点之间的最快路线。 BogoToBogo非常详细地介绍了如何实现Dijkstra算法的版本。我们首先决定尝试并实现他们的代码,以查看结果如何。但是,TypeError不断弹出:

TypeError:“ Vertex”和“ Vertex”的实例之间不支持“ <”

如果任何人知道导致此错误的原因,将不胜感激!

#Makes the shortest path from v.previous
def shortest(v, path):
    if v.previous:
        path.append(v.previous.get_id())
        shortest(v.previous, path)
    return

def dijkstra(aGraph, start, target):
    print('Dijkstras shortest path')
    # Set the distance for the start node to zero
    start.set_distance(0)

# Put tuple pair into the priority queue
unvisited_queue = [(v.get_distance(),v) for v in aGraph]
heapq.heapify(unvisited_queue)

while len(unvisited_queue):
    # Pops a vertex with the smallest distance
    uv = heapq.heappop(unvisited_queue)
    current = uv[1]
    current.set_visited()

    #for next in v.adjacent:
    for next in current.adjacent:
        # if visited, skip
        if next.visited:
            continue
        new_dist = current.get_distance() + current.get_weight(next)

        if new_dist < next.get_distance():
            next.set_distance(new_dist)
            next.set_previous(current)
            print('updated : current = ' + current.get_id() + ' next = ' + next.get_id() + ' new_dist = ' + next.get_distance())

        else:
            print('not updated : current = ' + current.get_id() + ' next = ' + next.get_id() + ' new_dist = ' + next.get_distance())

    # Rebuild heap
    # 1. Pop every item
    while len(unvisited_queue):
        heapq.heappop(unvisited_queue)
    # 2. Put all vertices not visited into the queue
    unvisited_queue = [(v.get_distance(),v) for v in aGraph if not v.visited]
    heapq.heapify(unvisited_queue)

if __name__ == "__main__":

# Calling the CSV loading functions in mainActivity
# These functions will also instantiate station and connections objects
load_stations(INPUT_STATIONS)
load_connections(INPUT_CONNECTIONS)

g = Graph()

for index in stations:
    g.add_vertex(stations[index].name)

for counter in connections:
    g.add_edge(connections[counter].stat1, connections[counter].stat2, int(connections[counter].time))

for v in g:
    for w in v.get_connections():
        vid = v.get_id()
        wid = w.get_id()
        print( vid, wid, v.get_weight(w))

dijkstra(g, g.get_vertex('Alkmaar'), g.get_vertex('Zaandam'))

target = g.get_vertex('Zaandam')
path = [target.get_id()]
shortest(target, path)
print('The shortest path :' + (path[::-1]))

在这种情况下,给定参数g(这是Graph类的实例),Alkmaar和Zaandam,将调用函数dijkstra。

# Represents a grid of nodes/stations composed of nodes and edges
class Graph:
    def __init__(self):
        self.vert_dict = {}
        self.num_vertices = 0

    def __iter__(self):
        return iter(self.vert_dict.values())

    def add_vertex(self, node):
        self.num_vertices = self.num_vertices + 1
        new_vertex = Vertex(node)
        self.vert_dict[node] = new_vertex
        return new_vertex

    def get_vertex(self, n):
        if n in self.vert_dict:
            return self.vert_dict[n]
        else:
            return None

    def add_edge(self, frm, to, cost = 0):
        if frm not in self.vert_dict:
            self.add_vertex(frm)
        if to not in self.vert_dict:
            self.add_vertex(to)

        self.vert_dict[frm].add_neighbor(self.vert_dict[to], cost)
        self.vert_dict[to].add_neighbor(self.vert_dict[frm], cost)

    def get_vertices(self):
        return self.vert_dict.keys()

    def set_previous(self, current):
        self.previous = current

    def get_previous(self, current):
        return self.previous

Graph类。

# Represents a node (station)
class Vertex:

    def __init__(self, node):
        self.id = node
        self.adjacent = {}
        # Set distance to infinity for all nodes
        self.distance = sys.maxsize
        # Mark all nodes unvisited
        self.visited = False
        # Predecessor
        self.previous = None

    def add_neighbor(self, neighbor, weight=0):
        self.adjacent[neighbor] = weight

    def get_connections(self):
        return self.adjacent.keys()

    def get_id(self):
        return self.id

    def get_weight(self, neighbor):
        return self.adjacent[neighbor]

    def set_distance(self, dist):
        self.distance = dist

    def get_distance(self):
        return self.distance

    def set_previous(self, prev):
        self.previous = prev

    def set_visited(self):
        self.visited = True

    def __str__(self):
        return str(self.id) + ' adjacent: ' + str([x.id for x in self.adjacent])

Vertex类。 谢谢您的宝贵时间!

1 个答案:

答案 0 :(得分:0)

我认为这可能会有所帮助,但是发布到stackoverflow的方式只是发布尽可能少的完整信息

# Put tuple pair into the priority queue
unvisited_queue = [(v.get_distance(),v) for v in aGraph]
heapq.heapify(unvisited_queue)

如果您看这段代码,它会将列表转换为需要<比较所提供内容的堆,在vertex类中定义__gt__()方法,该函数将确定获取什么首先弹出,所以请按照您认为合适的方式写出来,我认为错误会消失。 :-)