Multiple boxplots based on pandas groups

Posted: 2017-06-29 06:38:41

Tags: python pandas

Here is what my dataframe looks like:

year    item_id      sales_quantity
 2014     1            10
 2014     1             4
 ...      ...          ...

 2015     1             7
 2015     1             10
 ...     ...          ...
 2014     2             1
 2014     2             8
 ...      ...          ...

 2015     2             17
 2015     2             30
 ...     ...          ...
 2014     3             9
 2014     3             18
 ...     ...          ...

For each item_id, I want to plot a boxplot showing the distribution of sales_quantity for each year.

Here is my attempt:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

data = pd.DataFrame.from_csv('electronics.csv')
grouped = data.groupby(['year'])
ncols = 4
nrows = int(np.ceil(grouped.ngroups / ncols))

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(35, 45),
                         sharey=False)

for (key, ax) in zip(grouped.groups.keys(), axes.flatten()):
    grouped.get_group(key).boxplot(x='year', y='sales_quantity',
                                   ax=ax, label=key)

I get the error boxplot() got multiple values for argument 'x'. Can someone tell me how to do this?

If I only have one item, the following works: sns.boxplot(data.sales_quantity, groupby = data.year). How can I extend this to multiple items?
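Judging from the error, DataFrame.boxplot has no x or y parameters: it selects data with column= and groups it with by=, and extra keyword arguments are most likely passed straight on to matplotlib's Axes.boxplot, whose first parameter is already named x. The loop above also groups by year, while one subplot per item_id is what is wanted. Below is a minimal sketch of a corrected loop, assuming a small in-memory frame built from the sample rows posted in this thread in place of electronics.csv; ncols = 2 is an arbitrary choice for the illustration.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical stand-in for electronics.csv, using the sample rows from this thread
data = pd.DataFrame({
    'year':           [2014, 2014, 2015, 2015, 2014, 2014, 2015, 2015, 2014, 2014],
    'item_id':        [1, 1, 1, 1, 2, 2, 2, 2, 3, 3],
    'sales_quantity': [10, 4, 7, 10, 1, 8, 17, 30, 9, 18],
})

# One subplot per item, so group by item_id (not by year)
grouped = data.groupby('item_id')
ncols = 2
nrows = int(np.ceil(grouped.ngroups / ncols))

# squeeze=False keeps axes 2-D even when there is only one row
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False, figsize=(8, 6))
flat_axes = axes.flatten()

for (key, group), ax in zip(grouped, flat_axes):
    # DataFrame.boxplot takes column= and by=, not x= and y=
    group.boxplot(column='sales_quantity', by='year', ax=ax)
    ax.set_title('item_id {}'.format(key))

# Hide any grid cells left over when ngroups is not a multiple of ncols
for ax in flat_axes[grouped.ngroups:]:
    ax.set_visible(False)

plt.show()
```

pandas typically also adds a figure-level title of the form "Boxplot grouped by year" when by= is used; fig.suptitle('') clears it if it is not wanted.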

Link to csv

2 answers:

Answer 0 (score: 0)

I'll leave this simple version here for others...

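import pandas as pd
import matplotlib.pyplot as plt

# sample.txt is whitespace-delimited, so split on runs of whitespace
df = pd.read_csv('sample.txt', sep=r'\s+')

items = df['item_id'].unique()
fig, axes = plt.subplots(1, len(items), sharey=True)
for n, i in enumerate(items):
    # One column per year for this item; rows from the other year become NaN
    idf = df[df['item_id'] == i][['year', 'sales_quantity']].pivot(columns='year')
    print(idf)

    idf.plot.box(ax=axes[n])
    axes[n].set_title('Item ID {}'.format(i))
    # Columns are a MultiIndex like ('sales_quantity', 2014); label the ticks with the year only
    axes[n].set_xticklabels([e[1] for e in idf.columns])

plt.show()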

sample.txt:

year    item_id      sales_quantity
 2014     1            10
 2014     1             4
 2015     1             7
 2015     1             10
 2014     2             1
 2014     2             8
 2015     2             17
 2015     2             30
 2014     3             9
 2014     3             18
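In the code of this answer, pivot(columns='year') without values= keeps sales_quantity as the top level of a two-level column index, which is why the tick labels are rebuilt with e[1]. Passing values='sales_quantity' instead should give plain year columns. A small sketch of just the reshaping step for item 1, assuming the sample.txt contents above are read from a string with io.StringIO rather than from a file:

```python
import io
import pandas as pd

# The sample.txt contents from this answer, read from a string instead of a file
sample = """year item_id sales_quantity
2014 1 10
2014 1 4
2015 1 7
2015 1 10
2014 2 1
2014 2 8
2015 2 17
2015 2 30
2014 3 9
2014 3 18
"""
df = pd.read_csv(io.StringIO(sample), sep=r'\s+')

item1 = df[df['item_id'] == 1]
# values= gives one flat column per year; rows belonging to the other year become NaN
wide = item1.pivot(columns='year', values='sales_quantity')
print(wide)
```

Calling wide.plot.box(ax=axes[n]) on a frame like this labels each box with its year, and the NaN padding is dropped when the boxes are computed, so the set_xticklabels line becomes unnecessary.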

Answer 1 (score: 0)

Please check the comments in the code.

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('electronics_157_3cols.csv')
print(df)

items = df['item_id_copy'].unique()
fig, axes = plt.subplots(1, len(items), sharey=True)
for n, i in enumerate(items):
    # One column per year for this item; rows from other years become NaN
    idf = df[df['item_id_copy'] == i][['year', 'sales_quantity']].pivot(columns='year')
    print(idf)

    idf.plot.box(ax=axes[n])
    axes[n].set_title('ID {}'.format(i))
    # Columns are a MultiIndex like ('sales_quantity', 2014); label the ticks with the year only
    axes[n].set_xticklabels([e[1] for e in idf.columns], rotation=45)
    axes[n].set_ylim(0, 1)  # Remove this line to see the outliers properly (I kept it only to show a readable graph)

plt.show()

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('electronics_157_3cols.csv')
print(df)

fig, axes = plt.subplots(2, 5, sharey=True)

# Pair each cell of the 2 x 5 grid with one item id; zip stops when either runs out
for ax, i in zip(axes.flatten(), df['item_id_copy'].unique()):
    idf = df[df['item_id_copy'] == i][['year', 'sales_quantity']].pivot(columns='year')
    print(idf)

    idf.plot.box(ax=ax)
    ax.set_title('ID {}'.format(i))
    ax.set_xticklabels([e[1] for e in idf.columns], rotation=0)
    ax.set_ylim(0, 1)

plt.show()
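The set_ylim(0, 1) calls in this answer matter because matplotlib, by default, draws every point farther than 1.5 times the interquartile range from the box as an outlier (a "flier"), and a fixed y-range can clip those points. A minimal sketch for listing which values would be drawn as fliers, assuming matplotlib's default whis=1.5 and its default linear percentile interpolation; the helper name boxplot_outliers and the numbers are hypothetical:

```python
import pandas as pd


def boxplot_outliers(values, whis=1.5):
    """Return the values a default matplotlib box plot would draw as fliers."""
    s = pd.Series(values, dtype=float)
    q1, q3 = s.quantile(0.25), s.quantile(0.75)
    iqr = q3 - q1
    # Tukey fences: anything outside [q1 - whis*iqr, q3 + whis*iqr]
    return s[(s < q1 - whis * iqr) | (s > q3 + whis * iqr)].tolist()


# Hypothetical sales quantities for one item in one year
print(boxplot_outliers([1, 2, 3, 4, 100]))  # [100.0]
```

Applied to the sales_quantity values of one item and year, any returned value above 1 is a point that set_ylim(0, 1) would hide.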
