分类分类概率图

时间:2019-09-20 13:01:15

标签: python numpy matplotlib pytorch networkx

采用通常从神经网络输出的分类分类概率并使用诸如networkx之类的图形将其绘制为图表的最佳方法是什么。

给出8个类别,例如['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'],并对5个形状为(8, 5)的样本进行预测,我想绘制一个图,其中节点标签为类别,所有索引的概率为大于0.0。

array([[0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.17817138, 0.11233618, 0.12554741, 0.16154018, 0.16248149],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.70687366, 0.86215913, 0.85997397, 0.8285762 , 0.828603  ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ]],
      dtype=float32)

有效边位于数组中连续列的所有非零值之间。

1 个答案:

答案 0 :(得分:0)

如果我理解正确,那么您想要一个图形,其中顶点表示类,而边表示对于同一数据实例具有非零概率的两个类。

从与您的数据相似的数据开始(因为您的数据只有两个具有非零概率的类)...

import numpy as np

A = np.array([[0.        , 0.        , 0.        , 0.8285762 , 0.        ],
              [0.        , 0.        , 0.        , 0.        , 0.        ],
              [0.17817138, 0.        , 0.12554741, 0.        , 0.16248149],
              [0.        , 0.        , 0.        , 0.10000000, 0.        ],
              [0.70687366, 0.        , 0.85997397, 0.06154018, 0.828603  ],
              [0.        , 0.11233618, 0.        , 0.        , 0.        ],
              [0.11495425, 0.        , 0.        , 0.        , 0.        ],
              [0.        , 0.86215913, 0.        , 0.        , 0.        ]
             ])

这里可能会有一个狡猾的NumPy动作,但是通过遍历类,我们可以构建边缘列表:

from itertools import combinations

edge_list = []
idx = np.argwhere(A.T)
for i, _ in enumerate(A):
    this = idx[idx[:, 0]==i]
    combs = combinations(this[:, 1], r=2)
    edge_list.extend(list(combs))
edge_list

您可以实例化networkx.Graph并添加边,如下所示:

import networkx as nx

G = nx.Graph()
G.add_edges_from(edge_list)

然后您可以使用nx.draw(G)绘制图形,但是默认绘制效果不佳。我们可以准备节点大小和标签...

nx.set_node_attributes(G, {i:chr(65+i) for i in range(len(A))}, 'name')
nx.set_node_attributes(G, {i:sum(a) for i, a in enumerate(A)}, 'prob')

probs = nx.get_node_attributes(G, 'prob')
sizes = [1000*x for x in probs.values()]
labels = nx.get_node_attributes(G, 'name')

然后做一个更好的情节:

%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))

pos = nx.spring_layout(G)
nx.draw_networkx_nodes(G, pos, ax=ax, node_size=sizes, node_color='orange')
nx.draw_networkx_edges(G, pos, ax=ax, width=4, splines='curved')
nx.draw_networkx_labels(G, pos,
                        labels=labels,
                        font_size=20,
                        font_family='sans-serif',
                        font_color='blue')

plt.axis('off')
plt.show()

A plot of the resulting graph.