火炬:Dijkstra 算法

时间:2021-03-24 14:15:50

标签: python machine-learning deep-learning pytorch

我正在研究 3D 点云。我有点云图结构的稀疏矩阵表示(如 scipy.sparse 中的 csr_matrix)。我想将测地距离(由图中的路径长度近似)的某个阈值内的点聚集在一起,并将它们一起处理。为了找到这样的点,我需要运行一些像 Dijkstra 这样的最短路径查找算法。简而言之,我的想法是这样的

  1. 从 N 个点中采样 K 个点(我可以使用最远点采样来做)
  2. 为 K 个点中的每一个找到最近的测地线邻居(使用支持 BackProp 的算法)
  3. 使用一些神经网络处理每个点的邻居

这将用于我的转发功能。 有没有办法在我的功能中实现 Dijkstra 的功能?

或者我可以实施的任何其他想法?

非常感谢!

1 个答案:

答案 0 :(得分:0)

我使用优先级队列为 Dijkstra 创建了自定义实现here 同样,我使用如下的火炬函数创建了一个自定义 PriorityQ

class priorityQ_torch(object):
    """Priority Q implelmentation in PyTorch

    Args:
        object ([torch.Tensor]): [The Queue to work on]
    """

    def __init__(self, val):
        self.q = torch.tensor([[val, 0]])
        # self.top = self.q[0]
        # self.isEmpty = self.q.shape[0] == 0

    def push(self, x):
        """Pushes x to q based on weightvalue in x. Maintains ascending order

        Args:
            q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value]
            x ([torch.Tensor]): [[index, weight] tensor to be inserted]

        Returns:
            [torch.Tensor]: [The queue tensor after correct insertion]
        """
        if type(x) == np.ndarray:
            x = torch.tensor(x)
        if self.isEmpty():
            self.q = x
            self.q = torch.unsqueeze(self.q, dim=0)
            return
        idx = torch.searchsorted(self.q.T[1], x[1])
        print(idx)
        self.q = torch.vstack([self.q[0:idx], x, self.q[idx:]]).contiguous()

    def top(self):
        """Returns the top element from the queue

        Returns:
            [torch.Tensor]: [top element]
        """
        return self.q[0]

    def pop(self):
        """pops(without return) the highest priority element with the minimum weight

        Args:
            q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value]

        Returns:
            [torch.Tensor]: [highest priority element]
        """
        if self.isEmpty():
            print("Can Not Pop")
        self.q = self.q[1:]

    def isEmpty(self):
        """Checks is the priority queue is empty

        Args:
            q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value]

        Returns:
            [Bool] : [Returns True is empty]
        """
        return self.q.shape[0] == 0

现在是dijkstra,带有邻接矩阵(以图权重作为输入)

def dijkstra(adj):
    n = adj.shape[0]
    distance_matrix = torch.zeros([n, n])
    for i in range(n):
        u = torch.zeros(n, dtype=torch.bool)
        d = np.inf * torch.ones(n)
        d[i] = 0
        q = priorityQ_torch(i)
        while not q.isEmpty():
            v, d_v = q.top()  # point and distance
            v = v.int()
            q.pop()
            if d_v != d[v]:
                continue
            for j, py in enumerate(adj[v]):
                if py == 0 and j != v:
                    continue
                else:
                    to = j
                    weight = py
                    if d[v] + py < d[to]:
                        d[to] = d[v] + py
                        q.push(torch.Tensor([to, d[to]]))
        distance_matrix[i] = d
    return distance_matrix

返回图形点的最短路径距离矩阵!