Adjusting heatmap color intervals and removing colorbar ticks

Posted: 2020-07-13 12:58:12

Tags: python seaborn heatmap colorbar

I am trying to plot a heatmap with the following code:


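import matplotlib.colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_breast_cancer

breast_cancer = load_breast_cancer()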
data = breast_cancer.data
features = breast_cancer.feature_names
df = pd.DataFrame(data, columns = features)
df_small = df.iloc[:,:6]
correlation_mat = df_small.corr()

# Create color palette:

def NonLinCdict(steps, hexcol_array):
    cdict = {'red': (), 'green': (), 'blue': ()}
    for s, hexcol in zip(steps, hexcol_array):
        rgb = matplotlib.colors.hex2color(hexcol)
        cdict['red'] = cdict['red'] + ((s, rgb[0], rgb[0]),)
        cdict['green'] = cdict['green'] + ((s, rgb[1], rgb[1]),)
        cdict['blue'] = cdict['blue'] + ((s, rgb[2], rgb[2]),)
    return cdict
 
#https://www.december.com/html/spec/colorshades.html

hc = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080','#344152']
th = [0, 0.2, 0.4, 0.6, 0.8,1]

cdict = NonLinCdict(th, hc)
cm = matplotlib.colors.LinearSegmentedColormap('test', cdict)


#plot correlation matrix:

plt.figure(figsize = (12,10))

ax=sns.heatmap(correlation_mat,center=0, linewidths=.2, annot = True,cmap=cm , vmin=-1, vmax=1,cbar=True)

plt.title("title", y=-1.5,fontsize = 18)

plt.xlabel("X_parameters",fontsize = 18)

plt.ylabel("Y_parameters",fontsize = 18)

ax.tick_params(axis='both', which='both', length=0)

# choose colors
# change tick size and remove colorbar ticks
# add saving option
# change to 5 portions instead of four (0.2, 0.4, 0.6, 0.8)

plt.show()
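Side note on the palette helper: matplotlib's `LinearSegmentedColormap.from_list` also accepts a list of `(position, color)` pairs and builds the same kind of segment data that `NonLinCdict` assembles by hand. The sketch below is a minimal check of that equivalence, assuming the question's `hc` and `th` values; the names `manual_cmap` and `from_list_cmap` are only illustrative.

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, hex2color


def NonLinCdict(steps, hexcol_array):
    # Same helper as in the question: one (x, y0, y1) row per anchor color.
    cdict = {'red': (), 'green': (), 'blue': ()}
    for s, hexcol in zip(steps, hexcol_array):
        rgb = hex2color(hexcol)
        cdict['red'] += ((s, rgb[0], rgb[0]),)
        cdict['green'] += ((s, rgb[1], rgb[1]),)
        cdict['blue'] += ((s, rgb[2], rgb[2]),)
    return cdict


hc = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080', '#344152']
th = [0, 0.2, 0.4, 0.6, 0.8, 1]

# Hand-built segment data, as in the question.
manual_cmap = LinearSegmentedColormap('test', NonLinCdict(th, hc))

# from_list maps each position in `th` to the color paired with it.
from_list_cmap = LinearSegmentedColormap.from_list('test', list(zip(th, hc)))

# Sample both colormaps across [0, 1] and compare the RGBA values.
x = np.linspace(0, 1, 50)
print(np.allclose(manual_cmap(x), from_list_cmap(x)))
```

If the comparison holds on your matplotlib version, `LinearSegmentedColormap.from_list('test', list(zip(th, hc)))` can stand in for the helper.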

I have two open questions:
1. How do I remove the colorbar ticks?
2. With vmin = -1 and vmax = 1, how do I set the color and colorbar intervals to (-1, -0.8, -0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1)?

The attached image (Output) shows what I currently get.
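Some context for the second question: with a plain linear scale and vmin = -1, vmax = 1, each value v is rescaled to (v + 1) / 2 before a color is looked up, so colormap positions and data intervals differ by a factor of two. A small sketch of that mapping with matplotlib's `Normalize` (the default linear normalization behind `pcolormesh`), assuming the ranges from the question:

```python
import numpy as np
from matplotlib.colors import Normalize

# Data range used in the question: correlations in [-1, 1].
norm = Normalize(vmin=-1, vmax=1)

# Requested colorbar boundaries: -1, -0.8, ..., 0.8, 1 (11 values, step 0.2).
boundaries = np.linspace(-1, 1, 11)

# Normalize maps each value v to (v - vmin) / (vmax - vmin) = (v + 1) / 2,
# so these boundaries land at colormap positions 0, 0.1, ..., 1.
positions = norm(boundaries)
print(np.round(positions, 2))

# Going the other way: the question's anchor positions th = [0, 0.2, ..., 1]
# sit at data values -1, -0.6, -0.2, 0.2, 0.6, 1 (step 0.4, not 0.2).
th = np.array([0, 0.2, 0.4, 0.6, 0.8, 1])
print(np.round(norm.inverse(th), 2))
```

Under this assumption, getting a color stop at every 0.2 in data units takes 11 evenly spaced anchors in colormap space, not 6.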

1 Answer:

Answer 0 (score: 0)

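You can grab the colorbar via ax.collections[0].colorbar. From there you can change its tick properties, for example setting the tick length to 0 and choosing where the ticks go.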
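Why `ax.collections[0].colorbar` works: seaborn's heatmap draws with `pcolormesh`, which adds a `QuadMesh` to `ax.collections`, and creating a colorbar stores a back-reference on that mappable as its `colorbar` attribute. A plain-matplotlib sketch of the same pattern; the random data and the built-in 'Blues' colormap are placeholders, not part of the answer.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
data = rng.uniform(-1, 1, (10, 10))  # placeholder data in [-1, 1]

fig, ax = plt.subplots(figsize=(8, 6))
# pcolormesh adds one QuadMesh to ax.collections (seaborn's heatmap draws this way too).
mesh = ax.pcolormesh(data, cmap='Blues', vmin=-1, vmax=1)
# Creating the colorbar also stores it on the mappable as mesh.colorbar.
fig.colorbar(mesh, ax=ax)

# Retrieve the colorbar from the Axes alone, as in the answer.
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(axis='both', which='both', length=0)  # hide tick marks, keep labels
cbar.set_ticks(np.linspace(-1, 1, 11))                    # one tick every 0.2

print(cbar is mesh.colorbar)
# plt.show()  # uncomment to display the figure
```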

Here is a minimal example:

import matplotlib.colors  # needed for matplotlib.colors.hex2color / LinearSegmentedColormap below
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def NonLinCdict(steps, hexcol_array):
    cdict = {'red': (), 'green': (), 'blue': ()}
    for s, hexcol in zip(steps, hexcol_array):
        rgb = matplotlib.colors.hex2color(hexcol)
        cdict['red'] = cdict['red'] + ((s, rgb[0], rgb[0]),)
        cdict['green'] = cdict['green'] + ((s, rgb[1], rgb[1]),)
        cdict['blue'] = cdict['blue'] + ((s, rgb[2], rgb[2]),)
    return cdict

hc = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080', '#344152']
hc = hc[:0:-1] + hc  # prepend a reversed copy, but without repeating the central value
cdict = NonLinCdict(np.linspace(0, 1, len(hc)), hc)
cm = matplotlib.colors.LinearSegmentedColormap('test', cdict)

fig = plt.figure(figsize=(8, 6))
ax = sns.heatmap(np.random.uniform(-1, 1, (10, 10)), center=0, linewidths=.2, annot=True, fmt='.2f', cmap=cm, vmin=-1, vmax=1, cbar=True)
ax.tick_params(axis='both', which='both', length=0)

cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=15, axis='both', which='both', length=0)
cbar.set_ticks(np.linspace(-1, 1, 11))
# cbar.set_label('correlation')

plt.show()

[Image: example plot]
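The colormap above is still a smooth gradient: the ticks mark the 0.2 steps, but colors blend between them. If you want one solid color per interval instead, one option (a different technique from the answer) is a `ListedColormap` combined with a `BoundaryNorm`. A sketch in plain matplotlib, assuming 10 bands between -1 and 1 and reusing five of the question's blues mirrored around 0; the color choice is only a placeholder.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import BoundaryNorm, ListedColormap

# Five blues from the question (light to dark); '#344152' is left out so that
# mirroring yields exactly 10 colors. Placeholder choice, adjust freely.
blues = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080']
colors = blues[::-1] + blues  # dark -> light for negatives, light -> dark for positives

boundaries = np.linspace(-1, 1, 11)  # -1, -0.8, ..., 0.8, 1 -> 10 bins
cmap = ListedColormap(colors)        # 10 discrete colors
# One color per 0.2-wide bin; clip=True keeps v = 1 (e.g. a correlation
# matrix diagonal) in the top bin instead of the "over" slot.
norm = BoundaryNorm(boundaries, cmap.N, clip=True)

rng = np.random.default_rng(0)
data = rng.uniform(-1, 1, (10, 10))  # placeholder data

fig, ax = plt.subplots(figsize=(8, 6))
mesh = ax.pcolormesh(data, cmap=cmap, norm=norm)
cbar = fig.colorbar(mesh, ax=ax, ticks=boundaries)
cbar.ax.tick_params(length=0)  # remove colorbar tick marks, keep labels
# plt.show()  # uncomment to display the figure
```

`sns.heatmap` forwards extra keyword arguments to `pcolormesh`, so passing `cmap=cmap, norm=norm` there should behave the same way. Leave out `center=` (seaborn uses it to resample the colormap into 256 colors, which would break the one-color-per-bin mapping) and `vmin`/`vmax`, since the norm already defines the range.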