根据相似度(例如,余弦相似度等)对Holoviews热图的列和行进行重新排序

时间:2019-03-22 11:41:16

标签: sorting dataframe seaborn heatmap holoviews

我很惊讶以前似乎没有人问过这个问题。

假设我有一个熊猫数据框(随机示例),我可以使用Holoviews和Bokeh渲染器获得一个热图:

rownames = 'ABCDEFGHIJKLMNO'
df = pd.DataFrame(np.random.randint(0,20,size=(20, len(rownames))), columns=list(rownames))
hv.HeatMap({'x': df.columns, 'y': df.index, 'z': df}, 
           kdims=[('x', 'Col Categories'), ('y', 'Row Categories')], 
           vdims='z').opts(cmap="viridis", width=520, height=520)

enter image description here

数据(x和y)是分类的,因此行或列的初始顺序并不重要。我想根据一些相似度对行/列进行排序。

一种方法是使用seaborn clustermap:

heatmap_sns = sns.clustermap(df, metric="cosine", standard_scale=1, method="ward", cmap="viridis")

输出看起来像这样: enter image description here

已根据相似性对列和行进行了排序(在这种情况下,余弦基于点积;其他可用,例如“相关”等)。

但是,我想在Holoviews中显示集群图。如何更新Seaborn矩阵中原始数据帧的顺序?

2 个答案:

答案 0 :(得分:1)

可以使用以下方法从seaborn簇图中访问重新排序的列/行的索引:

> print(f'rows: {heatmap_sns.dendrogram_row.reordered_ind}')
> print(f'columns: {heatmap_sns.dendrogram_col.reordered_ind}')
rows: [5, 0, 13, 2, 18, 7, 4, 16, 12, 19, 14, 15, 10, 3, 8, 6, 17, 11, 1, 9]
columns: [7, 1, 10, 5, 9, 0, 8, 13, 2, 6, 14, 3, 4, 11, 12]

要更新原始数据框的行/列顺序:

# get col and row names by ID
colname_list = [df.columns[col_id] for col_id in heatmap_sns.dendrogram_col.reordered_ind]
rowname_list = [df.index[row_id] for row_id in heatmap_sns.dendrogram_row.reordered_ind]
# update dataframe
df_ro = df.reindex(rowname_list)
df_ro = df_ro[colname_list]

在这里,我首先获得名称,也许甚至有直接方法通过索引更新列/行。

hv.HeatMap({'x': df_ro.columns, 'y': df_ro.index, 'z': df_ro}, 
           kdims=[('x', 'Col Categories'), ('y', 'Row Categories')], 
           vdims='z').opts(cmap="viridis", width=520, height=520)

由于我使用的是随机数据,因此类别中的顺序很少,但图片看上去的噪点却少了一些。请注意,与seaborn clustermap-matrix相比,holoviews / df y轴完全相反,这就是图形看起来翻转的原因。

enter image description here

答案 1 :(得分:1)

一种更清晰的方法来解决Alex的问题(即早先被接受的答案)是使用library(dplyr) df %>% mutate(District = if_else(District == 2, 1, District)) %>% group_by(District) %>% summarise(col_to_sum = sum(col_to_sum)) 函数返回的对象的data2d属性。此属性包含重新排序的数据(即,聚类后的数据)。所以:

sns.clustermap()

替换以下所有行:

df_ro = heatmap_sns.data2d