我想从节点" a"获取所有LCA节点。和节点" o"。
在这个DiGraph中,节点" l"和节点" m"是LCA节点。
以下是代码。
import networkx as nx
def calc_length(Graph, node1, node2, elem):
length1 = nx.shortest_path_length(Graph, node1, elem)
length2 = nx.shortest_path_length(Graph, node1, elem)
length_sum = length1 + length2
return length_sum
G = nx.DiGraph() #Directed graph
G.add_nodes_from(["a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p"])
edges = [("a","b"),("b","c"),("b","d"),("a","e"),("a","h"),("e","f"),("e","g"),("e","i"),("h","l"),("h","m"),("g","j"),("o","p"),("o","n"),("n","m"),("n","l"),("n","k"),("p","j"),]
G.add_edges_from([(e[0], e[1]) for e in edges])
preds_1 = nx.bfs_predecessors(G, "a")
preds_2 = nx.bfs_predecessors(G, "o")
common_preds = set([n for n in preds_1]).intersection(set([n for n in preds_2]))
common_preds = list(common_preds)
dic ={}
for elem in common_preds:
length_sum = calc_length(G, "a", "o", elem)
dic[elem] = length_sum
min_num = min(dic.values())
for k, v in sorted(dic.items(), key=lambda x:x[1]):
if v != min_num:
break
else:
print k, v
我想要更快的执行速度。
如果你有一个比前面提到的方法更好的方法来解决问题,请告诉我。
我很感激你的帮助。
答案 0 :(得分:1)
这里有几个问题,其中一些我在评论中指出。问题的一部分是命名法令人困惑:最低共同祖先(as defined on wikipedia并且大概在计算机科学中)应该被命名为最低共同后代,以符合networkx所使用的命名法(以及任何理智)我知道的网络科学家)。因此,你的广度优先搜索应该真正跟随后代,而不是前辈。以下实现了这样的LCA搜索:
import numpy as np
import matplotlib.pyplot as plt; plt.ion()
import networkx as nx
def find_lowest_common_ancestor(graph, a, b):
"""
Find the lowest common ancestor in the directed, acyclic graph of node a and b.
The LCA is defined as on
@reference:
https://en.wikipedia.org/wiki/Lowest_common_ancestor
Notes:
------
This definition is the opposite of the term as it is used e.g. in biology!
Arguments:
----------
graph: networkx.DiGraph instance
directed, acyclic, graph
a, b:
node IDs
Returns:
--------
lca: [node 1, ..., node n]
list of lowest common ancestor nodes (can be more than one)
"""
assert nx.is_directed_acyclic_graph(graph), "Graph has to be acyclic and directed."
# get ancestors of both (intersection)
common_ancestors = list(nx.descendants(graph, a) & nx.descendants(graph, b))
# get sum of path lengths
sum_of_path_lengths = np.zeros((len(common_ancestors)))
for ii, c in enumerate(common_ancestors):
sum_of_path_lengths[ii] = nx.shortest_path_length(graph, a, c) \
+ nx.shortest_path_length(graph, b, c)
# print common_ancestors
# print sum_of_path_lengths
# return minima
minima, = np.where(sum_of_path_lengths == np.min(sum_of_path_lengths))
return [common_ancestors[ii] for ii in minima]
def test():
nodes = ["a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p"]
edges = [("a","b"),
("b","c"),
("b","d"),
("a","e"),
("a","h"),
("e","f"),
("e","g"),
("e","i"),
("h","l"),
("h","m"),
("g","j"),
("o","p"),
("o","n"),
("n","m"),
("n","l"),
("n","k"),
("p","j"),]
G = nx.DiGraph()
G.add_nodes_from(nodes)
G.add_edges_from(edges)
# plot
pos = nx.spring_layout(G)
nx.draw(G, pos)
nx.draw_networkx_labels(G, pos, labels=dict([(c, c) for c in 'abcdefghijklmnop']))
plt.show()
a,b = 'a','o'
lca = find_lowest_common_ancestor(G, a, b)
print "Lowest common ancestor(s) for {} and {}: {}".format(a, b, lca)
return
if __name__ == "__main__":
test()