Python绘制sankey并确定节点的顺序

时间:2020-11-10 09:43:01

标签: python plotly nodes sankey-diagram

我正在绘制一个sankey图,其中值从1个起始节点到5个音符(A,B,C,D和E)。我希望按字母顺序绘制节点,我认为这可以通过我的代码实现,但事实并非如此,从运行我的代码可以看到-我如何确保A后跟B和B然后是C等?

我有Python 3.9,并已将其密谋更新为版本4.12.0,但没有帮助。我在Jupyter笔记本电脑和Spyder(4.15)中都运行了代码,但是节点的顺序已关闭-您能建议我如何在代码中指定顺序吗?

import plotly.graph_objects as go
import plotly.express as px

source = [0, 0, 0, 0, 0]

target = [1, 2, 3, 4, 5]


value = [356, 16, 39, 6, 88]

label = ['Start', 'A', 'B', 'C', 'D', 'E']


color_node = ['#EBBAB5', 
'#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD']
color_link = ['#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD', 
'#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD','#EBBAB5', '#FEF3C7', '#A6E3D7','#98FB98', '#DDA0DD']

link = dict(source=source, target=target, value = value, color = color_link)
node = dict(label = label, pad=30, thickness=5, color = color_node)


data = go.Sankey(link = link, node = node)
fig = go.Figure(data)
fig.show()

1 个答案:

答案 0 :(得分:0)

您可以使用this feature

例如,虽然相当手工,但这可以达到目的:

n = 1/4
link = dict(source=source, target=target, value = value, color = color_link)
node = dict(label = label, 
            x = [0, 1, 1, 1, 1, 1],
            y = [0, 0*n, 1*n, 2*n, 3*n, 4*n],
            pad=30, 
            thickness=5, 
            color = color_node)

data = go.Sankey(
    link = link, 
    node = node,
    arrangement = "snap", 
)
fig = go.Figure(data)
fig.show()

我所做的只是指定节点的位置:x==0用于第一列,x==1 用于第二列。对于y,第二列中的节点使用0到1之间的偶数间距。我预期节点会重叠,但是似乎plotly.js中的逻辑可以为您解决这个问题。

您可以通过指定两个标签列表来以编程方式完成操作:

label_l = ['Start', ]
label_r = [ 'A', 'B', 'C', 'D', 'E']
...
node = dict(label = label_l + label_r, 
            x = [0, ]*len(label_l) + [1,]*len(label_r),
            y = list(np.linspace(0,1,len(label_l))) + list(np.linspace(0,1,len(label_r))),
            pad=30, 
            thickness=5, 
            color = color_node)

但是,我不保证它适用于更复杂的情况。