我正在研究 3D 点云。我有点云图结构的稀疏矩阵表示(如 scipy.sparse 中的 csr_matrix)。我想将测地距离(由图中的路径长度近似)的某个阈值内的点聚集在一起,并将它们一起处理。为了找到这样的点,我需要运行一些像 Dijkstra 这样的最短路径查找算法。简而言之,我的想法是这样的
这将用于我的转发功能。 有没有办法在我的功能中实现 Dijkstra 的功能?
或者我可以实施的任何其他想法?
非常感谢!
答案 0 :(得分:0)
我使用优先级队列为 Dijkstra 创建了自定义实现here
同样,我使用如下的火炬函数创建了一个自定义 PriorityQ
类
class priorityQ_torch(object):
"""Priority Q implelmentation in PyTorch
Args:
object ([torch.Tensor]): [The Queue to work on]
"""
def __init__(self, val):
self.q = torch.tensor([[val, 0]])
# self.top = self.q[0]
# self.isEmpty = self.q.shape[0] == 0
def push(self, x):
"""Pushes x to q based on weightvalue in x. Maintains ascending order
Args:
q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value]
x ([torch.Tensor]): [[index, weight] tensor to be inserted]
Returns:
[torch.Tensor]: [The queue tensor after correct insertion]
"""
if type(x) == np.ndarray:
x = torch.tensor(x)
if self.isEmpty():
self.q = x
self.q = torch.unsqueeze(self.q, dim=0)
return
idx = torch.searchsorted(self.q.T[1], x[1])
print(idx)
self.q = torch.vstack([self.q[0:idx], x, self.q[idx:]]).contiguous()
def top(self):
"""Returns the top element from the queue
Returns:
[torch.Tensor]: [top element]
"""
return self.q[0]
def pop(self):
"""pops(without return) the highest priority element with the minimum weight
Args:
q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value]
Returns:
[torch.Tensor]: [highest priority element]
"""
if self.isEmpty():
print("Can Not Pop")
self.q = self.q[1:]
def isEmpty(self):
"""Checks is the priority queue is empty
Args:
q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value]
Returns:
[Bool] : [Returns True is empty]
"""
return self.q.shape[0] == 0
现在是dijkstra,带有邻接矩阵(以图权重作为输入)
def dijkstra(adj):
n = adj.shape[0]
distance_matrix = torch.zeros([n, n])
for i in range(n):
u = torch.zeros(n, dtype=torch.bool)
d = np.inf * torch.ones(n)
d[i] = 0
q = priorityQ_torch(i)
while not q.isEmpty():
v, d_v = q.top() # point and distance
v = v.int()
q.pop()
if d_v != d[v]:
continue
for j, py in enumerate(adj[v]):
if py == 0 and j != v:
continue
else:
to = j
weight = py
if d[v] + py < d[to]:
d[to] = d[v] + py
q.push(torch.Tensor([to, d[to]]))
distance_matrix[i] = d
return distance_matrix
返回图形点的最短路径距离矩阵!