调整colorbar的位置并均衡子图的大小

时间:2014-09-26 10:36:53

标签: python numpy matplotlib plot colorbar

在我的上一个问题(previous question)没有得到任何答案之后,我尝试自己解决把 colorbar(而不是图例)添加到图中的问题。目前还有几个问题无法解决。更新:

  1. 我想将颜色条(colorbar)移动到图右侧的正确位置。
  2. 我使用相同的指令生成两个图,但第二个图看起来完全不同,我无法理解导致此问题的原因。
  3. 这是我的代码:

    import numpy as np
    import pylab as plt
    from matplotlib import rc,rcParams
    rc('text',usetex=True)
    rcParams.update({'font.size':10})
    import matplotlib.cm as cm
    from matplotlib.ticker import NullFormatter
    import matplotlib as mpl
    
    def plot(Z_s,CWL,filter_id,spectral_type,model_mag,mag,plot_name):
        """Scatter ``model_mag - mag`` against ``CWL/(1+Z_s)`` for each SED type.

        The six spectral types in ``nplist`` are drawn three at a time as
        stacked panels; after the third panel of each group a vertical
        colorbar labelled with the filter names is added and the figure is
        saved as ``<plot_name>.<type>.<type>.<type>.pdf`` and closed.

        Parameters
        ----------
        Z_s, CWL, model_mag, mag : numpy arrays of equal length
            Per-measurement values (presumably redshift, central wavelength
            and magnitudes -- confirm against the catalogue description).
        filter_id : numpy array
            Index of each measurement's filter in the list ``f`` below.
        spectral_type : numpy array
            Index of each measurement's SED type in ``nplist`` below.
        plot_name : str
            Prefix of the output file names.
        """
        f = ['U38','B','V','R','I','MB420','MB464','MB485','MB518','MB571',
             'MB604','MB646','MB696','MB753','MB815','MB856','MB914']
        nplist = ['E', 'Sbc', 'Scd', 'Irr', 'SB3', 'SB2']
        wavetable = CWL/(1+Z_s)
        dd = model_mag-mag

        # One colour per filter, sampled from 'gist_rainbow'.
        NUM_COLORS = len(f)
        cmap = plt.get_cmap('gist_rainbow')
        mycolor = [cmap(1.*i/NUM_COLORS) for i in range(NUM_COLORS)]  # RGBA tuples
        mymap = mpl.colors.LinearSegmentedColormap.from_list('mycolors',mycolor)

        # Dummy filled contour: it is cleared immediately and only serves as
        # the mappable handed to plt.colorbar() further down.
        Z = [[0,0],[0,0]]
        levels = list(np.linspace(0, 1, len(f)))
        CS3 = plt.contourf(Z, levels, cmap=mymap)
        plt.clf()
        # NOTE(review): plt.clf() clears the figure that contourf drew on but
        # does not close it, so the first plt.figure(1, figsize=...) below
        # presumably returns that existing figure with its original size,
        # whereas the second group gets a fresh figure after plt.close().
        # Confirm against the pyplot.figure documentation.

        FILTER = filter_id
        SED = spectral_type
        fontsize = 8
        for j, sed_name in enumerate(nplist):
            in_sed = (SED == j)
            k = j % 3  # panel row inside the current group of three
            fig = plt.figure(1, figsize=(5,5))
            ax = fig.add_subplot(3,1,k+1)
            for i, rgba in enumerate(mycolor):
                # Combine both conditions into one mask over the full arrays;
                # indices taken from FILTER[in_sed] would refer to the
                # filtered subset and select the wrong rows of wavetable/dd.
                sel = in_sed & (FILTER == i)
                ax.scatter(wavetable[sel], dd[sel], s=1, color=rgba[:3])
            if k < 2:
                # Only the bottom panel keeps its x tick labels.
                ax.xaxis.set_major_formatter(NullFormatter())
            else:
                ax.set_xlabel(r'WL($\AA$)',fontsize=10)
            ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
            fig.subplots_adjust(wspace=0,hspace=0)
            ax.axhline(y=0,color='k')
            ax.set_xlim(1000,9000)
            ax.set_ylim(-3,3)
            ax.set_xticks(np.linspace(1000, 9000, 16, endpoint=False))
            ax.set_yticks(np.linspace(-3, 3, 4, endpoint=False))
            ax.text(8500,2.1,sed_name, {'color': 'k', 'fontsize': 10})
            for tick in ax.xaxis.get_major_ticks():
                tick.label1.set_fontsize(fontsize)
            for tick in ax.yaxis.get_major_ticks():
                tick.label1.set_fontsize(fontsize)
            if k == 2:
                # Third panel of the group: add the colorbar, save, close.
                cbar_ax = fig.add_axes([0.9, 0.15, 0.05, 0.7])
                cbar = plt.colorbar(CS3, cax=cbar_ax, ticks=range(0,len(f)),orientation='vertical')
                cbar.ax.get_yaxis().set_ticks([])
                for s, lab in enumerate(f):
                    cbar.ax.text(0.08,(0.95-0.01)/float(len(f)-1) * s, lab, fontsize=8,ha='left')
                fname = plot_name+'.'+nplist[j-2]+'.'+nplist[j-1]+'.'+nplist[j]+'.pdf'
                plt.savefig(fname)
                plt.close()
    
    
    # Load the catalogue and hand its first six columns to plot(), in order:
    # Z_s, CWL, filter_id, spectral_type, model_mag, mag.
    a = np.loadtxt('calibration.photometry.information.capak.cat')
    Z_s, CWL, filter_id, spectral_type, model_mag, mag = (
        a[:, col] for col in range(6)
    )
    plot_name = 'test'
    plot(Z_s, CWL, filter_id, spectral_type, model_mag, mag, plot_name)
    

您也可以从这里(here)下载数据。非常感谢您的帮助。

1 个答案:

答案 0 :(得分:5)

您可以使用plt.subplots()传递gridspec_kw参数以非常灵活的方式调整轴的纵横比,然后选择顶部轴以包含颜色条。

我研究了你的代码并对其进行了简化。此外,我还修改了代码中的许多地方,例如:按 PEP8 调整格式,删除了对 plt.savefig() 和 ax 方法的重复调用。结果是:

import numpy as np
import pylab as plt
from matplotlib import rc, rcParams, colors

rc('text', usetex=True)
rcParams['font.size'] = 10
rcParams['axes.labelsize'] = 8

def plot(Z_s, CWL, filter_id, spectral_type, model_mag, mag, plot_name):
    """Scatter ``model_mag - mag`` against ``CWL/(1+Z_s)`` for each SED type.

    Two figures are built with ``plt.subplots(nrows=5, ...)``.  In each, row 0
    holds a horizontal colorbar labelled with the filter names, row 1 is an
    empty spacer, and rows 2-4 hold one panel per spectral type.  The first
    figure shows ``nplist[0:3]`` and the second ``nplist[3:6]``; they are
    saved as ``<plot_name>.E.Sbc.Scd.png`` and ``<plot_name>.Irr.SB3.SB2.png``.

    Parameters
    ----------
    Z_s, CWL, model_mag, mag : numpy arrays of equal length
        Per-measurement values (presumably redshift, central wavelength and
        magnitudes -- confirm against the catalogue description).
    filter_id : numpy array
        Index of each measurement's filter in the list ``f`` below.
    spectral_type : numpy array
        Index of each measurement's SED type in ``nplist`` below.
    plot_name : str
        Prefix of the output file names.
    """
    f = ['U38', 'B', 'V', 'R', 'I', 'MB420', 'MB464', 'MB485', 'MB518',
         'MB571', 'MB604', 'MB646', 'MB696', 'MB753', 'MB815', 'MB856',
         'MB914']
    nplist = ['E', 'Sbc', 'Scd', 'Irr', 'SB3', 'SB2']
    wavetable = CWL/(1+Z_s)
    dd = model_mag-mag

    # One colour per filter, sampled from 'gist_rainbow'.
    NUM_COLORS = len(f)
    cmap = plt.get_cmap('gist_rainbow')
    mycolor = [cmap(1.*i/NUM_COLORS) for i in range(NUM_COLORS)]
    mymap = colors.LinearSegmentedColormap.from_list('mycolors', mycolor)

    # Dummy filled contour with len(f) bands; it only serves as the mappable
    # handed to plt.colorbar() below.
    Z = [[0, 0], [0, 0]]
    levels = list(np.linspace(0, 1, len(f)+1))
    CS3 = plt.contourf(Z, levels, cmap=mymap)
    # Midpoints between consecutive array values, used as the x positions (in
    # axes coordinates) of the filter labels.  Assumes get_array() returns
    # the len(f)+1 level boundaries -- TODO confirm for the matplotlib
    # version in use.
    coords = CS3.get_array()
    coords = coords[:-1] + np.diff(coords)/2.

    FILTER = filter_id
    SED = spectral_type
    dummy = 2  # number of leading rows not used for data (colorbar + spacer)
    xmin = 1000
    xmax = 9000
    ymin = -3
    ymax = 3
    fig, axes = plt.subplots(nrows=5, figsize=(5, 6),
            gridspec_kw=dict(height_ratios=[0.35, 0.05, 1, 1, 1]))
    fig2, axes2 = plt.subplots(nrows=5, figsize=(5, 6),
            gridspec_kw=dict(height_ratios=[0.35, 0.05, 1, 1, 1]))
    fig.subplots_adjust(wspace=0, hspace=0)
    fig2.subplots_adjust(wspace=0, hspace=0)
    axes_all = np.concatenate((axes[dummy:], axes2[dummy:]))
    dummy_axes = np.concatenate((axes[:dummy], axes2[:dummy]))

    # Common styling of the six data panels.
    for ax in axes_all:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        ax.axhline(y=0, color='k')
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)
        ax.set_xticks([])
        ax.set_yticks(np.linspace(ymin, ymax, 4, endpoint=False))
        ax.set_ylabel(r'$\Delta$ MAG', fontsize=10)

    # Only the bottom panel of each figure gets x ticks and an x label.
    for bottom in (axes[-1], axes2[-1]):
        bottom.set_xticks(np.linspace(xmin, xmax, 16, endpoint=False))
        plt.setp(bottom.xaxis.get_majorticklabels(), rotation=30)
        bottom.set_xlabel(r'WL($\AA$)', fontsize=10)

    # Strip spines and ticks from the colorbar and spacer rows.
    for ax in dummy_axes:
        for s in ax.spines.values():
            s.set_visible(False)
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        ax.set_xticks([])
        ax.set_yticks([])

    # Draw the colorbar into the top row of each figure and write the filter
    # names over the colour bands.
    for axes_i in [axes, axes2]:
        cbar = plt.colorbar(CS3, ticks=[], orientation='horizontal',
                            cax=axes_i[0])
        for s, lab in enumerate(f):
            cbar.ax.text(coords[s], 0.5, lab, fontsize=8, va='center',
                         ha='center', rotation=90,
                         transform=cbar.ax.transAxes)

    for j, sed_name in enumerate(nplist):
        # First three SED types go to `fig`, the last three to `fig2`.
        panels = axes if j < 3 else axes2
        ax = panels[j % 3 + dummy]
        ax.text(8500, 2.1, sed_name, {'color': 'k', 'fontsize': 10})
        in_sed = (SED == j)
        for i, color in enumerate(mycolor):
            # Combine both conditions into one mask over the full arrays;
            # indices taken from FILTER[in_sed] would refer to the filtered
            # subset and select the wrong rows of wavetable/dd.
            sel = in_sed & (FILTER == i)
            ax.scatter(wavetable[sel], dd[sel], s=1, color=color)

    fname = '.'.join([plot_name] + nplist[:3] + ['png'])
    fig.savefig(fname)
    fname = '.'.join([plot_name] + nplist[3:] + ['png'])
    fig2.savefig(fname)

if __name__ == '__main__':
    # Load the catalogue and hand its first six columns to plot(), in order:
    # Z_s, CWL, filter_id, spectral_type, model_mag, mag.
    a = np.loadtxt('calibration.photometry.information.capak.cat')
    Z_s, CWL, filter_id, spectral_type, model_mag, mag = (
        a[:, col] for col in range(6)
    )
    plot_name = 'test'
    plot(Z_s, CWL, filter_id, spectral_type, model_mag, mag, plot_name)

给出:

enter image description here