Below is my resulting plot, but I would like it to look like the truncated dendrograms in astrodendro, such as this one:
I am trying to recreate the very cool dendrogram from this paper in matplotlib.
Below is the code that builds the iris dataset with added noise variables and plots the dendrogram in matplotlib.
Does anyone know how to (1) truncate the branches as in the example figures, and/or (2) use astrodendro with a custom linkage matrix and labels?
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
# import astrodendro  # not used below; only needed for the astrodendro approach
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial import distance

def iris_data(noise=None):
    # Iris dataset
    iris = load_iris()
    X = pd.DataFrame(iris.data,
                     index=[f"iris_{x}" for x in range(150)],
                     columns=[x.split(" (cm)")[0].replace(" ", "_") for x in iris.feature_names])
    y = pd.Series(iris.target, index=X.index, name="Species")
    # `map_colors` was an undefined helper; map species to colors directly instead
    c = y.map({0: "red", 1: "green", 2: "blue"})
    if noise is not None:
        X_noise = pd.DataFrame(
            np.random.RandomState(0).normal(size=(X.shape[0], noise)),
            index=X.index,  # was X_iris.index, which is undefined here
            columns=[f"noise_{x}" for x in range(noise)]
        )
        X = pd.concat([X, X_noise], axis=1)
    return (X, y, c)

def dism2linkage(DF_dism, method="ward"):
    """
    Input: An (m x m) dissimilarity pandas DataFrame whose diagonal is 0
    Output: Hierarchical clustering encoded as a linkage matrix
    Further reading:
    http://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.cluster.hierarchy.linkage.html
    https://pypi.python.org/pypi/fastcluster
    """
    # Linkage matrix. DataFrame.as_matrix() was removed from pandas; use to_numpy().
    # checks=False tolerates a diagonal that rounding leaves only approximately 0.
    Ar_dist = distance.squareform(DF_dism.to_numpy(), checks=False)
    return linkage(Ar_dist, method=method)

# Get data
X_iris_with_noise, y_iris, c_iris = iris_data(50)
# Get distance matrix
df_dism = 1 - X_iris_with_noise.corr().abs()
# Get linkage matrix
Z = dism2linkage(df_dism)
# Create dendrogram
# (the style is called "seaborn-white" on Matplotlib < 3.6)
with plt.style.context("seaborn-v0_8-white"):
    fig, ax = plt.subplots(figsize=(13, 3))
    D_dendro = dendrogram(
        Z,
        labels=df_dism.index,
        color_threshold=3.5,
        count_sort="ascending",
        # link_color_func=lambda k: colors[k],
        ax=ax,
    )
    ax.set_ylabel("Distance")
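For reference, the linkage matrix `Z` that `dism2linkage` returns has one row per merge: the two cluster indices being joined, the distance at which they merge, and the number of leaves in the new cluster. The sketch below runs the same `squareform` then `linkage` steps on a hypothetical 4 x 4 dissimilarity matrix (the values are made up for illustration). It uses `method="average"` instead of `"ward"` only so that the merge heights are easy to check by hand.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.spatial import distance

# Hypothetical symmetric dissimilarity matrix with a zero diagonal.
D = np.array([
    [0.0, 0.1, 0.8, 0.9],
    [0.1, 0.0, 0.7, 0.8],
    [0.8, 0.7, 0.0, 0.2],
    [0.9, 0.8, 0.2, 0.0],
])

# squareform turns the square matrix into the condensed vector that linkage expects.
condensed = distance.squareform(D)
Z = linkage(condensed, method="average")

# Each row: [cluster_a, cluster_b, merge_distance, n_leaves_in_new_cluster]
# New clusters get indices n, n+1, ... (here 4, 5, ...).
for a, b, dist, n in Z:
    print(int(a), int(b), round(dist, 3), int(n))
```

The last row always merges everything, so its fourth column equals the number of leaves.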
Answer (score: 1)
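I'm not sure this really counts as an answer, but it does let you produce a dendrogram with truncated hanging lines. The trick is to generate the plot as normal and then post-process the resulting matplotlib artists to recreate the lines.
I couldn't get your example to work locally, so I just created a dummy dataset.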
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
import numpy as np
a = np.random.multivariate_normal([0, 10], [[3, 1], [1, 4]], size=[5,])
b = np.random.multivariate_normal([0, 10], [[3, 1], [1, 4]], size=[5,])
X = np.concatenate((a, b),)
Z = linkage(X, 'ward')
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
dendrogram(Z, ax=ax)
The resulting plot is the usual dendrogram with long arms reaching down to y = 0.
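Now for the more interesting bit. The dendrogram is made up of several LineCollection objects (one per color). To update the lines, we iterate over these collections, extract the vertices of their component paths, modify every point that reaches y = 0 so the hanging line stops short of the axis, and then build a new LineCollection from the modified paths.
The new collection is then added to the axes, and the original one is removed.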
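To see the structure this approach relies on, you can inspect the axes right after calling `dendrogram`. A minimal sketch, assuming SciPy's current plotting code, which adds one `LineCollection` per color and draws each link as a 4-vertex U-shape. This is an implementation detail of SciPy, not a documented API, so it may change between versions.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from scipy.cluster.hierarchy import dendrogram, linkage

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 2))
Z = linkage(X, "ward")

fig, ax = plt.subplots()
dendrogram(Z, ax=ax)

# One LineCollection per color group; one path per link (n - 1 links for n leaves).
line_collections = [c for c in ax.collections if isinstance(c, LineCollection)]
n_links = sum(len(c.get_paths()) for c in line_collections)
print(len(line_collections), "LineCollection(s),", n_links, "links")

# Each link is a U-shape: left leg bottom, left top, right top, right leg bottom.
print(line_collections[0].get_paths()[0].vertices)
```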
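One tricky part is deciding what height to draw down to instead of zero. Since we iterate over each dendrogram path in isolation, we don't know which point came before; we basically don't know where we are. However, we can exploit the fact that the hanging lines hang vertically. Assuming no two lines share the same x, we can look up the other known y values for a given x and use them as the basis for calculating the new y. The downside is that, to be sure we have this number, we have to pre-scan the data first.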
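The pre-scan idea can be isolated into a small function that works on one path's vertices. A sketch, where `shorten_legs` and its fixed `gap` are hypothetical names chosen for illustration; it follows the same rule as the answer's loop: replace y = 0 by the lowest positive y seen at the same x, minus a gap, clamped at zero.

```python
import numpy as np

def shorten_legs(vertices, gap=0.5):
    """Return a copy of an (N, 2) vertex array with its y == 0 points lifted.

    Hypothetical helper: each point at y == 0 is moved to
    (lowest positive y at the same x) - gap, but never below 0.
    """
    verts = np.asarray(vertices, dtype=float).copy()

    # Pre-pass: lowest positive y for every x in this path.
    lowest = {}
    for x, y in verts:
        if 0 < y < lowest.get(x, np.inf):
            lowest[x] = y

    # Second pass: lift the points that touch the baseline.
    for i, (x, y) in enumerate(verts):
        if y == 0:
            verts[i, 1] = max(0.0, lowest.get(x, 0.0) - gap)
    return verts

# A U-shaped link as SciPy draws it: leaves at x=5 and x=15 joined at height 2.
u = [[5, 0], [5, 2], [15, 2], [15, 0]]
print(shorten_legs(u))
```

Because the gap is in data units, a join lower than `gap` keeps its legs at 0; scaling `gap` to the tree height is one possible refinement.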
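Note: if two dendrogram hanging lines can end up at the same x, you would need to take y into account as well and search for the nearest y above the point at that x.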
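One way to avoid keying on x at all is to rely on vertex order: in each U-shaped link, a leg's bottom point is always adjacent to the top of that same leg. A sketch under that assumption, with `shorten_legs_by_neighbor` as a hypothetical helper name:

```python
import numpy as np

def shorten_legs_by_neighbor(vertices, gap=0.5):
    """Lift y == 0 endpoints using the adjacent vertex on the same vertical leg.

    Hypothetical helper. Instead of an x-keyed lookup, it checks the previous
    and next vertex in path order and uses the one that shares the same x.
    """
    original = np.asarray(vertices, dtype=float)
    out = original.copy()
    for i, (x, y) in enumerate(original):
        if y != 0:
            continue
        # Candidate neighbors: previous and next vertex in the path.
        for j in (i - 1, i + 1):
            if 0 <= j < len(original) and original[j, 0] == x and original[j, 1] > 0:
                out[i, 1] = max(0.0, original[j, 1] - gap)
                break
    return out

# Left leg starts at a leaf (y=0); right leg starts at a sub-cluster (y=0.4).
link = [[5, 0], [5, 1], [15, 1], [15, 0.4]]
print(shorten_legs_by_neighbor(link))
```

Legs that start at a sub-cluster (y > 0) are left untouched, so only the leaf-level hanging lines are shortened.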
import numpy as np
from matplotlib.collections import LineCollection

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
dendrogram(Z, ax=ax)

for c in list(ax.collections):  # iterate over a copy, since we add to the same container
    paths = []
    for path in c.get_paths():
        segments = []
        y_at_x = {}
        # Pre-pass over all vertices, to find the lowest y value at each x value.
        # We use this to calculate where to cut our lines.
        for x, y in path.vertices:
            # Don't store if the y is zero, or if it's higher than the current low.
            if y > 0 and y < y_at_x.get(x, np.inf):
                y_at_x[x] = y
        for x, y in path.vertices:
            if y == 0:
                # If we know the lowest y at this x, use it - 0.5, limited to >= 0
                y = max(0, y_at_x.get(x, 0) - 0.5)
            segments.append([x, y])
        paths.append(segments)
    lc = LineCollection(paths, colors=c.get_color())  # Recreate a LineCollection with the same colors
    ax.add_collection(lc)
    c.remove()  # Remove the original (ax.collections.remove(c) fails on newer Matplotlib)
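A different route, which avoids post-processing matplotlib artists altogether, is to let SciPy compute the layout with `dendrogram(..., no_plot=True)` and draw the links yourself from the returned `icoord`/`dcoord`/`color_list`. A minimal sketch; `truncated_dendrogram` and `gap` are hypothetical names, and it assumes SciPy's usual leaf spacing (leaves at x = 5, 15, 25, ...).

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage

def truncated_dendrogram(Z, ax, gap=0.5, **kwargs):
    """Draw a dendrogram whose leaf legs stop `gap` below their join.

    Hypothetical helper: SciPy computes the layout (no_plot=True) and we
    draw each U-shaped link ourselves with shortened leaf legs.
    """
    ddata = dendrogram(Z, no_plot=True, **kwargs)
    for xs, ys, color in zip(ddata["icoord"], ddata["dcoord"], ddata["color_list"]):
        ys = list(ys)
        # ys = [left_bottom, top, top, right_bottom]; a bottom of 0 is a leaf.
        if ys[0] == 0:
            ys[0] = max(0.0, ys[1] - gap)
        if ys[3] == 0:
            ys[3] = max(0.0, ys[2] - gap)
        ax.plot(xs, ys, color=color)
    # Leaves sit at x = 5, 15, 25, ... in SciPy's dendrogram coordinates.
    ax.set_xticks(np.arange(5, 10 * len(ddata["ivl"]), 10))
    ax.set_xticklabels(ddata["ivl"], rotation=90)
    return ddata

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 2))
Z = linkage(X, "ward")

fig, ax = plt.subplots()
ddata = truncated_dendrogram(Z, ax, gap=0.5)
```

Extra keyword arguments such as `labels=` or `color_threshold=` are passed straight through to `dendrogram`, so the question's options should carry over.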
The resulting dendrogram looks like this: