Center Seaborn Colorbar标签

时间:2018-01-24 22:48:14

标签: python matplotlib seaborn

我正在查看离散数据,但是seaborn中的颜色条似乎只是为连续变量设置。

我的代码会生成我想要的图表,但颜色条上的标签不会与各自的颜色对齐。

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

df = pd.DataFrame(np.random.randint(0,6,size=(52, 7)))
colors = {0:'#05926f', 1:'#99cc33', 2:'#F26419', 3:'#F6AE2D', 4:'#06AED5', 5:'#3baa5d'}
cmap = mpl.colors.ListedColormap(list(colors.values()))
fig, ax = plt.subplots(figsize=(3,8))
ax = sns.heatmap(df, annot=True, cmap=cmap, linewidths=.05)
plt.plot()

我尝试将cbar_kws={"anchor": (0.5, 0.5)}作为参数添加到sns.heatmap调用中,但是它会抛出一个错误,说它不喜欢anchor关键字。但是seaborn documentation说我应该可以使用来自pyplot.colorbar的参数,但没有运气。没关系,请参阅下面的评论。

我不确定如何将该标签居中。我错过了一些明显的东西吗?

我还希望将标签从#更改为Group-#,如果这是一个简单的解决方案。

谢谢!

1 个答案:

答案 0 :(得分:1)

您可能想要使用边界规范。问题是matplotlib不知道你想要哪个值对应哪种颜色。要提供此信息,可以使用BoundaryNorm,指定颜色的bin边缘。在此示例中,您将整数0,1,2,3,4,5作为值,您的bin边缘最好选择为-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5,以使值位于bin的中心。

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
#%matplotlib inline

df = pd.DataFrame(np.random.randint(0,6,size=(52, 7)))
colors = {0:'#05926f', 1:'#99cc33', 2:'#F26419', 3:'#F6AE2D', 4:'#06AED5', 5:'#3baa5d'}

colorrange = range(len(list(colors.keys())))
colorlist = [colors[i] for i in colorrange]
cmap = mpl.colors.ListedColormap(colorlist)
bounds = np.array(range(len(list(colors.keys()))+1))-0.5
norm = mpl.colors.BoundaryNorm(bounds, len(colorrange))

fig, ax = plt.subplots(figsize=(4.2,8))
fig.subplots_adjust(right=0.8)
ax = sns.heatmap(df, annot=True, cmap=cmap, norm=norm, 
                 cbar_kws={'format': 'Group-%g'}, linewidths=.05)
plt.show()

enter image description here

数据范围未从0开始但仍为N个连续整数的代码:

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
#%matplotlib inline

df = pd.DataFrame(np.random.randint(1,7,size=(52, 7)))
colors = {6:'#05926f', 1:'#99cc33', 2:'#F26419', 3:'#F6AE2D', 4:'#06AED5', 5:'#3baa5d'}

colorrange = sorted(list(colors.keys()))
colorlist = [colors[i] for i in colorrange]
cmap = mpl.colors.ListedColormap(colorlist)
bounds = np.array(colorrange+[max(colorrange)+1])-0.5
norm = mpl.colors.BoundaryNorm(bounds, len(colorrange))

fig, ax = plt.subplots(figsize=(4.2,8))
fig.subplots_adjust(right=0.8)
ax = sns.heatmap(df, annot=True, cmap=cmap, norm=norm, 
                 cbar_kws={'format': 'Group-%g'}, linewidths=.05)
plt.show()