如何比较数据并从熊猫的multiIndex数据框中选择TOP 2?

时间:2018-07-22 08:33:07

标签: python pandas

如何比较数据并从大熊猫的multiIndex数据框中选择TOP 2或TOP5?您会在此示例中看到,如果foo仅获得一条记录,则只会选择一条记录。但是,如果有三个记录,将选择TOP2记录。

例如:

arrays = [np.array(['bar', 'bar', 'bar', 'bar', 'baz','baz', 'baz', 'qux', 'qux','qux', 'qux','foo']),
          np.array(['AA', 'AB', 'AC','AD', 'BA', 'BB', 'BC', 'CA', 'CB', 'CC', 'CD', 'DA'])]
df = pd.DataFrame(np.random.randn(12, 1), index=arrays)
df

OUT:

         0
bar AA  -0.754077
    AB   0.924327
    AC   0.146192
    AD  -0.718730
baz BA  -0.143378
    BB   1.098409
    BC   0.703452
qux CA   0.729626
    CB   0.232755
    CC   0.827796
    CD   0.914639
foo DA  -0.289108

最后,我要这样选择:

         0
bar AB   0.924327
    AC   0.146192     
baz BB   1.098409
    BC   0.703452
qux CC   0.827796
    CD   0.914639
foo DA  -0.289108

1 个答案:

答案 0 :(得分:1)

使用:

np.random.seed(234)
arrays = [np.array(['bar', 'bar', 'bar', 'bar', 'baz','baz', 'baz', 'qux', 'qux','qux', 'qux','foo']),
          np.array(['AA', 'AB', 'AC','AD', 'BA', 'BB', 'BC', 'CA', 'CB', 'CC', 'CD', 'DA'])]
df = pd.DataFrame(np.random.randn(12, 1), index=arrays)
print (df)
               0
bar AA  0.818792
    AB -1.043551
    AC  0.350901
    AD  0.921578
baz BA -0.087382
    BB -3.128885
    BC -0.969733
qux CA  0.934666
    CB  0.043866
    CC  1.425216
    CD -0.557063
foo DA  0.926824

使用SeriesGroupBy.nlargest的解决方案:

s = df.groupby(level=0)[0].nlargest(2).reset_index(level=0, drop=True)
print (s)
bar  AD    0.921578
     AA    0.818792
baz  BA   -0.087382
     BC   -0.969733
foo  DA    0.926824
qux  CC    1.425216
     CA    0.934666
Name: 0, dtype: float64

如果需要避免对MultiIndex进行排序:

df1 = (df.groupby(level=0, sort=False)[0]
       .nlargest(2)
       .reset_index(level=0, drop=True)
       .to_frame())
print (df1)

               0
bar AD  0.921578
    AA  0.818792
baz BA -0.087382
    BC -0.969733
qux CC  1.425216
    CA  0.934666
foo DA  0.926824

另一种解决方案,可以与pandas 0.23.0+sort_values一起在GroupBy.head中使用:

df.index.names = ['lvl1','lvl2']
df.columns = ['a']
s = df.sort_values(['lvl1', 'a'], ascending=[True, False]).groupby(level=0).head(2)
print (s)
                  a
lvl1 lvl2          
bar  AD    0.921578
     AA    0.818792
baz  BA   -0.087382
     BC   -0.969733
foo  DA    0.926824
qux  CC    1.425216
     CA    0.934666
相关问题