在python

时间:2018-03-16 10:58:06

标签: python-3.x networkx

g=nx.DiGraph(directed=True)
g.add_nodes_from(o)
for j in range (len(o)):
   for i in range(len(ixx)):
      g.add_edge(ixx[i],o[j-1])
   g.add_edge(o[j-1],WIN[0], weight=10)   
nx.draw(g,with_labels=True)
plt.draw()
plt.show()

这是图表的代码。 ixx是输入节点,WIN是单输出节点,o是隐藏节点。<​​/ p>

示例网络看起来像这样。 (这是我运行代码时得到的) Numbers from 1 to ...26 are hidden nodes. 27 is output node

但是,我想绘制它:左侧输入节点,中间隐藏输出节点,右侧输出节点。就像神经网络的外观一样。

1 个答案:

答案 0 :(得分:0)

以下是如何做到这一点的演示:

import networkx as nx
import matplotlib.pyplot as plt
from random import sample, seed 

seed(0)
G=nx.dodecahedral_graph()

# splitting the graph to the sets of nodes
samples = sample(G.nodes,2)
output_set = set(samples[:1])
input_set = set(samples[1:2])
hidden_set = set(G.nodes) - output_set - input_set

# a function for shifting node positions
def make_spring_pos(nodes,shift):
  return { node:(pos[0]+shift,pos[1]) for node,pos in nx.spring_layout(G.subgraph(nodes)).items() }

# shifting node positions for different sets
input_pos= make_spring_pos(input_set, shift=0)
hidden_pos= make_spring_pos(hidden_set, shift=3)
output_pos= make_spring_pos(output_set, shift=6)

# merging positions back
all_pos = {}
for pos in [input_pos,hidden_pos,output_pos]:
  all_pos.update(pos)

nx.draw_networkx(G, pos=all_pos)

plt.show()

output