如何检测Matplotlib PathCollection中的哪个项目被点击?

时间:2018-04-19 15:36:19

标签: python matplotlib mouseevent networkx

我可以使用Matplotlib绘制有向网络图。现在我希望能够响应鼠标事件,以便用户可以与网络进行交互。例如,节点可以在用户单击时更改其颜色。这只是一个例子,但它说明了这一点。我也想知道点击了哪个节点(标签);我对它在空间中的x,y坐标并不感兴趣。

到目前为止,这是我的代码:

import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
import networkx as nx

def createDiGraph():
    G = nx.DiGraph()

    # Add nodes:
    nodes = ['A', 'B', 'C', 'D', 'E']
    G.add_nodes_from(nodes)

    # Add edges or links between the nodes:
    edges = [('A','B'), ('B','C'), ('B', 'D'), ('D', 'E')]
    G.add_edges_from(edges)
    return G

G = createDiGraph()

# Get a layout for the nodes according to some algorithm.
pos = nx.layout.spring_layout(G, random_state=779)

node_size = 300
nodes = nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color=(0,0,0.9),
                                edgecolors='black')
# nodes is a matplotlib.collections.PathCollection object
nodes.set_picker(5)
#nodes.pick('button_press_event')
#print(nodes.get_offsets())

nx.draw_networkx_edges(G, pos, node_size=node_size, arrowstyle='->',
                                arrowsize=15, edge_color='black', width=1)

nx.draw_networkx_labels(G, pos, font_color='red', font_family='arial',
                                font_size=10)

def onpick(event):
    #print(event.mouseevent)

    if isinstance(event.artist, PathCollection):
        #nodes = event.artist
        print (event.ind)

fig = plt.gcf()

# Bind our onpick() function to pick events:
fig.canvas.mpl_connect('pick_event', onpick)

# Hide the axes:
plt.gca().set_axis_off()
plt.show()

绘制时网络如下所示:

enter image description here

例如,如果我在节点C上单击鼠标,程序将打印出[1];或[3]表示节点E.注意,索引不对应于0表示A,1表示B,2表示C,依此类推,即使原始节点按顺序添加到networkx有向图中也是如此

那么当我点击节点C时,如何获得值' C'如何在图中抓住表示C的对象,以便我可以改变它的颜色?

我尝试过使用PathCollection.pick,但我不确定要传递给它的是什么,我不确定这是否是正确的使用方法。

2 个答案:

答案 0 :(得分:1)

看起来G在创建节点时不会保留节点的顺序,但节点的顺序似乎存储在pos中。我建议如下:

添加到您的import语句:

from ast import literal_eval

label中定义nx.draw_networkx_nodes并在函数update_plot中创建以Gposcolor为参数的图:

def update_plot(pos, G, colors):
    # the keys of the pos dictionary contains the labels that you are interested in
    # or label = [*pos]
    nodes = nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color=colors,edgecolors='black', label=list(pos.keys()))
    # nodes is a matplotlib.collections.PathCollection object
    nodes.set_picker(5)
    nx.draw_networkx_edges(G, pos, node_size=node_size, arrowstyle='->', arrowsize=15, edge_color='black', width=1)
    nx.draw_networkx_labels(G, pos, font_color='red', font_family='arial', font_size=10)

您的onpick功能应为:

def onpick(event):
    if isinstance(event.artist, PathCollection):
        #index of event
        ind = event.ind[0]
        #convert the label list from a string back to a list
        label_list = literal_eval(event.artist.get_label())
        print(label_list[ind])
        colors = [(0, 0, 0.9)] * len(pos)
        colors[ind]=(0.9,0,0)
        update_plot(pos, G, colors)

您的代码正文就是:

G = createDiGraph()
# Get a layout for the nodes according to some algorithm.
pos = nx.layout.spring_layout(G, random_state=779)

node_size = 300
colors=[(0,0,0.9)]*len(pos)
update_plot(pos, G, colors)

答案 1 :(得分:1)

在@apogalacticon的有用评论之后,我最终得到了这个解决方案。

import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
import networkx as nx

def createDiGraph():
    G = nx.DiGraph()

    # Add nodes:
    nodes = ['A', 'B', 'C', 'D', 'E']
    G.add_nodes_from(nodes)

    # Add edges or links between the nodes:
    edges = [('A','B'), ('B','C'), ('B', 'D'), ('D', 'E')]
    G.add_edges_from(edges)
    return G

G = createDiGraph()

# Get a layout for the nodes according to some algorithm.
pos = nx.layout.spring_layout(G, random_state=779)

node_size = 300
nodes = nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color=(0,0,0.9),
                                edgecolors='black')
# nodes is a matplotlib.collections.PathCollection object
nodes.set_picker(5)

nx.draw_networkx_edges(G, pos, node_size=node_size, arrowstyle='->',
                                arrowsize=15, edge_color='black', width=1)

nx.draw_networkx_labels(G, pos, font_color='red', font_family='arial',
                                font_size=10)

def onpick(event):

    if isinstance(event.artist, PathCollection):
        all_nodes = event.artist
        ind = event.ind[0] # event.ind is a single element array.
        this_node_name = pos.keys()[ind]
        print(this_node_name)

        # Set the colours for all the nodes, highlighting the picked node with
        # a different colour:
        colors = [(0, 0, 0.9)] * len(pos)
        colors[ind]=(0, 0.9, 0)
        all_nodes.set_facecolors(colors)

        # Update the plot to show the change:
        fig.canvas.draw() # plt.draw() also works.
        #fig.canvas.flush_events() # Not required? See https://stackoverflow.com/a/4098938/1843329

fig = plt.gcf()

# Bind our onpick() function to pick events:
fig.canvas.mpl_connect('pick_event', onpick)

# Hide the axes:
plt.gca().set_axis_off()
plt.show()

单击节点时,其面部颜色变为绿色,而其他节点设置为蓝色:

enter image description here

不幸的是,似乎网络中的节点由单个matplotlib.artist.Artist对象表示,该对象恰好是没有任何艺术家孩子的路径集合。这意味着您无法获得单个节点以更改其属性。相反,您不得不更新所有节点,只需确保所选节点的属性(在这种情况下为颜色)与其他节点不同。