Matplotlib直方图与收集箱的高值

时间:2014-10-06 14:42:23

标签: python matplotlib histogram bins

我有一个包含值的数组,我想创建它的直方图。我主要对低端号码感兴趣,并希望在一个箱子中收集300以上的每个号码。此箱应具有与所有其他(同样宽)箱相同的宽度。我怎么能这样做?

注意:此问题与此问题有关:Defining bin width/x-axis scale in Matplotlib histogram

这是我到目前为止所尝试的:

import matplotlib.pyplot as plt
import numpy as np

def plot_histogram_01():
    np.random.seed(1)
    values_A = np.random.choice(np.arange(600), size=200, replace=True).tolist()
    values_B = np.random.choice(np.arange(600), size=200, replace=True).tolist()

    bins = [0, 25, 50, 75, 100, 125, 150, 175, 200, 225, 250, 275, 300, 600]

    fig, ax = plt.subplots(figsize=(9, 5))
    _, bins, patches = plt.hist([values_A, values_B], normed=1,  # normed is deprecated and will be replaced by density
                                bins=bins,
                                color=['#3782CC', '#AFD5FA'],
                                label=['A', 'B'])

    xlabels = np.array(bins[1:], dtype='|S4')
    xlabels[-1] = '300+'

    N_labels = len(xlabels)
    plt.xlim([0, 600])
    plt.xticks(25 * np.arange(N_labels) + 12.5)
    ax.set_xticklabels(xlabels)

    plt.yticks([])
    plt.title('')
    plt.setp(patches, linewidth=0)
    plt.legend()

    fig.tight_layout()
    plt.savefig('my_plot_01.png')
    plt.close()

这是结果,看起来不太好看: enter image description here

然后我用xlim更改了这行:

plt.xlim([0, 325])

具有以下结果: enter image description here

它或多或少看起来像我想要的那样,但现在看不到最后一个bin。我错过了哪个技巧可以看到最后一个宽度为25的bin?

2 个答案:

答案 0 :(得分:30)

Numpy有一个方便的功能来解决这个问题:np.clip。尽管名称可能听起来像,但它不会删除值,它只是将它们限制在您指定的范围内。基本上,它是Artem的“脏黑客”内联。您可以按原样保留值,但在hist调用中,只需将数组包装在np.clip调用中,就像这样

plt.hist(np.clip(values_A, bins[0], bins[-1]), bins=bins)

由于多种原因,这样做更好:

  1. 方式更快 - 至少对于大量元素而言。 Numpy在C级工作。在python列表上操作(如在Artem的列表理解中)对每个元素都有很多开销。基本上,如果你有选择使用numpy,你应该。

  2. 您可以在需要的地方执行此操作,从而减少在代码中出错的可能性。

  3. 您不需要保留数组的第二个副本,这样可以减少内存使用量(这一行除外)并进一步减少出错的可能性。

  4. 使用bins[0], bins[-1]而不是对值进行硬编码可以减少再次出错的可能性,因为您可以更改bins定义的位置;您无需记得在致电clip或其他任何地方时更改它们。

  5. 所以把它们放在一起就像在OP中一样:

    import matplotlib.pyplot as plt
    import numpy as np
    
    def plot_histogram_01():
        np.random.seed(1)
        values_A = np.random.choice(np.arange(600), size=200, replace=True)
        values_B = np.random.choice(np.arange(600), size=200, replace=True)
    
        bins = np.arange(0,350,25)
    
        fig, ax = plt.subplots(figsize=(9, 5))
        _, bins, patches = plt.hist([np.clip(values_A, bins[0], bins[-1]),
                                     np.clip(values_B, bins[0], bins[-1])],
                                    # normed=1,  # normed is deprecated; replace with density
                                    density=True,
                                    bins=bins, color=['#3782CC', '#AFD5FA'], label=['A', 'B'])
    
        xlabels = bins[1:].astype(str)
        xlabels[-1] += '+'
    
        N_labels = len(xlabels)
        plt.xlim([0, 325])
        plt.xticks(25 * np.arange(N_labels) + 12.5)
        ax.set_xticklabels(xlabels)
    
        plt.yticks([])
        plt.title('')
        plt.setp(patches, linewidth=0)
        plt.legend(loc='upper left')
    
        fig.tight_layout()
    plot_histogram_01()
    

    result of code above

答案 1 :(得分:5)

抱歉,我不熟悉matplotlib。所以我对你有一个肮脏的黑客。我只是将所有大于300的值放在一个bin中并更改了bin大小。

问题的根源在于matplotlib试图将所有垃圾箱放在图上。在R中我会将我的箱子转换为因子变量,因此它们不被视为实数。

import matplotlib.pyplot as plt
import numpy as np

def plot_histogram_01():
    np.random.seed(1)
    values_A = np.random.choice(np.arange(600), size=200, replace=True).tolist()
    values_B = np.random.choice(np.arange(600), size=200, replace=True).tolist()
    values_A_to_plot = [301 if i > 300 else i for i in values_A]
    values_B_to_plot = [301 if i > 300 else i for i in values_B]

    bins = [0, 25, 50, 75, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325]

    fig, ax = plt.subplots(figsize=(9, 5))
    _, bins, patches = plt.hist([values_A_to_plot, values_B_to_plot], normed=1,  # normed is deprecated and will be replaced by density
                                bins=bins,
                                color=['#3782CC', '#AFD5FA'],
                                label=['A', 'B'])

    xlabels = np.array(bins[1:], dtype='|S4')
    xlabels[-1] = '300+'

    N_labels = len(xlabels)

    plt.xticks(25 * np.arange(N_labels) + 12.5)
    ax.set_xticklabels(xlabels)

    plt.yticks([])
    plt.title('')
    plt.setp(patches, linewidth=0)
    plt.legend()

    fig.tight_layout()
    plt.savefig('my_plot_01.png')
    plt.close()

plot_histogram_01()

enter image description here