在 Matplotlib 条形图中让标签文本随条形长度缩放

时间:2019-10-15 10:58:41

标签: python matplotlib

我正在用 GTD(全球恐怖主义数据库)数据集制作动态条形图竞赛(bar chart race)。问题是各组织的名称相对于条形来说太长了。有没有办法让每个条形的标签按该条形的数值进行缩放?也就是说,短条形的标签变小,长条形的标签变大。

这是我正在使用的代码:

# Shared figure/axes that every frame is drawn onto.
fig, ax = plt.subplots(figsize=(15, 8))


def draw_barchart(year):
    """Draw the top-10 groups by cumulative casualties up to ``year`` on ``ax``.

    Rows of ``df2`` with ``iyear <= year`` are summed per group (``gname``) on
    ``wound_killed``; the ten largest totals are drawn as horizontal bars with
    the largest on top, each group's name at the inner end of its bar and the
    total just past the end of the bar.

    NOTE(review): in the pasted snippet the ``def`` line was missing, leaving
    the body indented under nothing (an IndentationError). The name
    ``draw_barchart`` is a reconstruction taken from the common bar-chart-race
    recipe this code follows — confirm against the real source.

    Relies on the module-level names ``df2``, ``plt`` and ``ticker``, which
    are defined elsewhere.

    Args:
        year: Last year (inclusive) whose rows are accumulated.
    """
    # Every call draws a complete frame on the shared axes, so drop the
    # previous frame's bars and text first.
    ax.clear()

    # Only 'wound_killed' is needed; summing the text column 'country_txt'
    # merely concatenated country names per group.
    top10 = (
        df2[df2['iyear'] <= year]
        .groupby(['gname'], as_index=False)[['wound_killed']]
        .sum()
        .sort_values('wound_killed', ascending=False)
        .head(10)
    )
    top10 = top10[::-1]  # barh draws the first row at the bottom; largest goes on top
    ax.barh(top10['gname'], top10['wound_killed'])

    dx = top10['wound_killed'].max() / 200  # small gap between a bar's end and its text
    for i, (value, name) in enumerate(zip(top10['wound_killed'], top10['gname'])):
        ax.text(value - dx, i, name, size=14, weight=600, ha='right', va='bottom')  # group name
        ax.text(value + dx, i, f'{value:,.0f}', size=14, ha='left', va='center')  # casualty total

    # ... polished styles
    ax.text(1, 0.4, year, transform=ax.transAxes, color='#777777', size=46, ha='right', weight=800)  # year stamp
    ax.text(0, 1.06, 'Number of Casualties', transform=ax.transAxes, size=12, color='#777777')  # caption above the plot
    ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))  # thousands separators on ticks
    ax.xaxis.set_ticks_position('top')  # x ticks on top rather than bottom
    ax.tick_params(axis='x', colors='#777777', labelsize=12)
    ax.set_yticks([])  # names are drawn on the bars instead
    ax.margins(0, 0.01)
    ax.grid(which='major', axis='x', linestyle='-')
    ax.set_axisbelow(True)
    ax.text(0, 1.12, 'The devastating Terror Groups from 1970 to 2018',  # title
            transform=ax.transAxes, size=24, weight=600, ha='left')
    ax.set_frame_on(False)  # no axes box (what plt.box(False) does on this axes)

1 个答案:

答案 0(得分:-1)

我找到了一种方法。为了结束这个问题,我把答案贴出来,以防其他人也想知道。

# Module-level default year; top_10_year() takes the year as its own
# parameter, so it does not read this value.
year: int = 1970
def top_10_year(year):
    """Render and save the top-10 casualties bar chart for a single year.

    Rows of ``df2`` with ``iyear <= year`` are summed per group (``gname``)
    on ``wound_killed``; the ten largest totals are drawn as horizontal bars
    with the largest on top. Each group's name is written at the inner end of
    its bar with a font size that grows with the bar's length relative to the
    longest bar (8 pt below 10 % of the maximum, then 10/12/14 pt for each
    further 10 %, 16 pt from 40 % up), so short bars get small labels. The
    casualty total is written just past the end of each bar.

    The figure is saved to ``../plots/manual_animation/<year - 1969>.jpg``
    (1970 -> ``1.jpg``) and closed afterwards.

    Relies on the module-level names ``df2``, ``plt`` and ``ticker``, which
    are defined elsewhere.

    Args:
        year: Last year (inclusive) whose rows are accumulated.
    """
    fig, ax = plt.subplots(figsize=(28, 10))

    # Only 'wound_killed' is summed; the text column 'country_txt' was never
    # used, and summing it merely concatenated country names per group.
    top10 = (
        df2[df2['iyear'] <= year]
        .groupby(['gname'], as_index=False)[['wound_killed']]
        .sum()
        .sort_values('wound_killed', ascending=False)
        .head(10)
    )
    top10 = top10[::-1]  # barh draws the first row at the bottom; largest goes on top
    ax.barh(top10['gname'], top10['wound_killed'])

    max_value = top10['wound_killed'].max()  # computed once, not per bar
    dx = max_value / 200  # small horizontal gap between a bar's end and its text

    # (exclusive upper bound of value / max_value, name font size); the first
    # matching bound wins, anything larger gets 16.
    size_steps = ((0.1, 8), (0.2, 10), (0.3, 12), (0.4, 14))

    for i, (value, name) in enumerate(zip(top10['wound_killed'], top10['gname'])):
        # Keep only the text inside the first "(...)", presumably an acronym
        # ("Some Group (SG)" -> "SG"). partition() also copes with a missing
        # or earlier ")" for which a find()-based slice returns a truncated
        # or empty string.
        if '(' in name:
            name = name.partition('(')[2].partition(')')[0]
        # Guard an all-zero top 10: 0 / 0 would yield NaN.
        ratio = value / max_value if max_value else 0
        name_size = next((size for bound, size in size_steps if ratio < bound), 16)
        ax.text(value - dx, i, name, size=name_size, weight=600, ha='right', va='bottom')  # group name
        ax.text(value + dx, i, f'{value:,.0f}', size=16, ha='left', va='center')  # casualty total

    # ... polished styles
    ax.text(1, 0.4, year, transform=ax.transAxes, color='#777777', size=46, ha='right', weight=800)  # year stamp
    ax.text(0, 1.06, 'Number of Casualties', transform=ax.transAxes, size=12, color='#777777')  # caption above the plot
    ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))  # thousands separators on ticks
    ax.xaxis.set_ticks_position('top')  # x ticks on top rather than bottom
    ax.tick_params(axis='x', colors='#777777', labelsize=12)
    ax.set_yticks([])  # names are drawn on the bars instead
    ax.margins(0, 0.01)
    ax.grid(which='major', axis='x', linestyle='-')
    ax.set_axisbelow(True)
    ax.text(0, 1.12, 'The devastating Terror Groups from 1970 to 2018',  # title
            transform=ax.transAxes, size=24, weight=600, ha='left')
    ax.set_frame_on(False)  # no axes box (what plt.box(False) does on this axes)

    fig.savefig('../plots/manual_animation/' + str(year - 1969) + '.jpg')
    # A new figure is created on every call; close it so rendering many years
    # does not keep every figure open (pyplot warns past 20 open figures).
    plt.close(fig)

# Smoke test: render and save only the 1970 frame.
top_10_year(year=1970)

最后我实际上采用了另一种方法:根据标签(名称)的长度和条形的长度,决定把标签放在条形的哪一侧(条形内部末端或条形外侧)。

year: int = 1970  # module-level default; top_10_year2() takes its own `year` parameter
# Switch to the jupyterthemes 'chesterish' plotting theme before any chart is
# drawn (presumably `from jupyterthemes import jtplot` elsewhere; assumed to
# change global matplotlib style state — confirm).
jtplot.style(theme='chesterish')
def top_10_year2(year):
    """Render and save the top-10 casualties chart, placing labels by bar length.

    Rows of ``df2`` with ``iyear <= year`` are summed per group (``gname``)
    on ``wound_killed``; the ten largest totals are drawn as horizontal bars
    with the largest on top. Label placement depends on the bar's length
    relative to the longest bar and on the (shortened) name's length:

    * bar longer than half the maximum: name inside the bar end (16 pt, or
      14 pt for names of 40+ characters), total just after the bar;
    * otherwise: name just after the bar end (16 pt, or 14 pt for names of
      40+ characters on bars at 30-50 % of the maximum), total inside the
      bar end.

    The figure is saved to ``../plots/label_left/<year - 1969>.jpg``
    (1970 -> ``1.jpg``) and closed afterwards.

    Relies on the module-level names ``df2``, ``plt`` and ``ticker``, which
    are defined elsewhere.

    Args:
        year: Last year (inclusive) whose rows are accumulated.
    """
    fig, ax = plt.subplots(figsize=(28, 10))

    # Only 'wound_killed' is summed; the text column 'country_txt' was never
    # used, and summing it merely concatenated country names per group.
    top10 = (
        df2[df2['iyear'] <= year]
        .groupby(['gname'], as_index=False)[['wound_killed']]
        .sum()
        .sort_values('wound_killed', ascending=False)
        .head(10)
    )
    top10 = top10[::-1]  # barh draws the first row at the bottom; largest goes on top
    ax.barh(top10['gname'], top10['wound_killed'])

    max_value = top10['wound_killed'].max()  # computed once, not per bar
    dx = max_value / 200  # small horizontal gap between a bar's end and its text

    for i, (value, name) in enumerate(zip(top10['wound_killed'], top10['gname'])):
        # Keep only the text inside the first "(...)", presumably an acronym
        # ("Some Group (SG)" -> "SG"). partition() also copes with a missing
        # or earlier ")" for which a find()-based slice returns a truncated
        # or empty string.
        if '(' in name:
            name = name.partition('(')[2].partition(')')[0]
        # Guard an all-zero top 10: 0 / 0 would yield NaN, which previously
        # fell through to a 4 pt label drawn on top of the total.
        ratio = value / max_value if max_value else 0
        long_name = len(name) >= 40

        if ratio > 0.5:
            # Long bar: the name sits inside the bar end, the total after it.
            name_x, name_ha = value - dx, 'right'
            total_x, total_ha = value + dx, 'left'
            name_size = 14 if long_name else 16
        else:
            # Short bar: the name goes after the bar end, the total inside it.
            # Long names shrink only from 30 % of the maximum upwards
            # (presumably because less room remains to the right of the bar).
            name_x, name_ha = value + dx, 'left'
            total_x, total_ha = value - dx, 'right'
            name_size = 14 if long_name and ratio >= 0.3 else 16

        ax.text(name_x, i, name, size=name_size, weight=600, ha=name_ha, va='bottom')  # group name
        ax.text(total_x, i, f'{value:,.0f}', size=16, ha=total_ha, va='center')  # casualty total

    # ... polished styles
    ax.text(1, 0.4, year, transform=ax.transAxes, color='#777777', size=46, ha='right', weight=800)  # year stamp
    ax.text(0, 1.06, 'Number of Casualties', transform=ax.transAxes, size=18, color='#777777')  # caption above the plot
    ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))  # thousands separators on ticks
    ax.xaxis.set_ticks_position('top')  # x ticks on top rather than bottom
    ax.tick_params(axis='x', colors='#777777', labelsize=16)
    ax.set_yticks([])  # names are drawn on/next to the bars instead
    ax.margins(0, 0.01)
    ax.grid(which='major', axis='x', linestyle='-')
    ax.set_axisbelow(True)
    ax.text(0, 1.12, 'The devastating Terror Groups from 1970 to 2018',  # title
            transform=ax.transAxes, size=24, weight=600, ha='left')
    ax.set_frame_on(False)  # no axes box (what plt.box(False) does on this axes)

    fig.savefig('../plots/label_left/' + str(year - 1969) + '.jpg')
    # A new figure is created on every call; close it so rendering many years
    # does not keep every figure open (pyplot warns past 20 open figures).
    plt.close(fig)

# Smoke test: render and save only the 1985 frame.
top_10_year2(year=1985)