如何制作具有一致且对齐良好的垃圾箱的分组直方图?

时间:2020-07-10 13:04:34

标签: python matplotlib

我想对存储在pandas DataFrame中的数据进行直方图绘制,其中直方图根据该数据帧中的另一列分为两组(我们将其称为{{1} }列,可以是1或0)。我很难让两个小组的垃圾箱以合理的方式对齐。

这是我到目前为止所拥有的:

target

这是结果:

def fun_histByTarget(df, cols, target):
    target = df[target]
    if isinstance(cols, str):
        cols = [cols]
    fig = plt.figure(figsize=(18, 5 * ((len(cols) + 1) // 2)), dpi= 80)
    for i in range(len(cols)):
        sp = fig.add_subplot((len(cols) + 1) // 2, 2, i + 1)
        col = df[cols[i]].copy()
        sp.hist(col[target==0], color='red',  alpha=.3, label='target = 0', align='left')
        sp.hist(col[target==1], color='blue', alpha=.3, label='target = 1', align='left')
        sp.legend()
        sp.set_title(cols[i])

result_1

我尝试手动添加垃圾箱

fun_histByTarget(test, 'integer_col', 'target')

但这没有帮助。结果仓的选择非常奇怪,因此即使所有数据都是整数,直方图的某些条也完全落在两个整数值之间。可能是因为我已经硬编码了10个垃圾箱。但是,很难自动选择正确数量的垃圾箱。有更好的方法吗?

1 个答案:

答案 0 :(得分:1)

要为两者获得相同的直方图块,只需使用具有完全相同边界的bins=参数即可。因此,目前尚不清楚为什么您的测试无法正常进行。 (很难看到没有使用的确切代码。)

除此之外,列名'integer_col'暗示只有整数的列。直方图主要用于处理连续数据。如果只有整数,并且将bin边界创建为np.linspace(1, 7, 10),则将在[1.0, 1.667, 2.333, 3.0, 3.667, 4.333, 5.0, 5.667, 6.333, 7.0]上有9个带有怪异边界的bin。因此,整数值1会落在第一个bin中,值2会落在第二个bin中,值3会落在第三个或第四个bin中(取决于浮点舍入误差),...更方便的bin选择是{ 1}},如下面的代码所示。 (我还将0.5, 1.5, 2.5, ...更改为默认的align='left',以使条形图与其相应值位于同一位置。)

align='mid'

resulting plot

如果要避免重叠的条形图,则条形图具有更多选项,但是您需要在单独的步骤中计算计数(例如,使用import matplotlib.pyplot as plt import numpy as np import pandas as pd def fun_histByTarget(df, cols, target): target = df[target] if isinstance(cols, str): cols = [cols] fig = plt.figure(figsize=(18, 5 * ((len(cols) + 1) // 2)), dpi=80) for i in range(len(cols)): ax = fig.add_subplot((len(cols) + 1) // 2, 2, i + 1) col = df[cols[i]] bins = np.arange(col.min() - 0.5, col.max() + 0.5001, (col.max() - col.max()) // 20 + 1) ax.hist(col[target == 0], bins=bins, color='red', alpha=.3, label='target = 0', align='mid') ax.hist(col[target == 1], bins=bins, color='blue', alpha=.3, label='target = 1', align='mid') ax.legend() ax.set_title(cols[i]) target = np.random.randint(0, 2, 100) integer_col = np.where(target == 0, np.random.randint(1, 7, target.size), np.random.randint(1, 6, target.size)) test = pd.DataFrame({'integer_col': integer_col, 'target': target}) fun_histByTarget(test, 'integer_col', 'target') plt.show() 或通过np.hist)。 / p>