熊猫如何将函数应用于带参数的groupby对象

时间:2019-02-26 17:00:18

标签: python pandas dataframe pandas-groupby

我有一个df

cluster_id    memo
   1          m
   1          n
   2          m
   2          m
   2          n
   3          m
   3          m
   3          m
   3          n
   4          m
   4          n
   4          n
   4          n

我要groupby cluster_id并应用以下功能,

def valid_row_dup(df):
    num_real_invs = df[df['memo'] == 'm'].shape[0]
    num_reversals_invs = df[df['memo'] == 'n'].shape[0]

    if num_real_invs == df.shape[0]:
        return True
    elif num_reversals_invs == df.shape[0]:
        return False
    elif abs(num_real_invs - num_reversals_invs) > 0:
        # even diff
        if abs(num_real_invs - num_reversals_invs) % 2 == 0:
            return True
        else:
            if abs(num_real_invs - num_reversals_invs) == 1:
                return False
            # odd diff
            else:
                return True
    elif num_real_invs - num_reversals_invs == 0:
        return False 

将每个groupby对象作为df传递到func中;将布尔结果分配回df

cluster_id    memo     valid
   1          m        False
   1          n        False
   2          m        False
   2          m        False
   2          n        False
   3          m        True
   3          m        True
   3          m        True
   3          n        True
   4          m        True
   4          n        True
   4          n        True   
   4          n        True

3 个答案:

答案 0 :(得分:1)

应用您的函数然后合并:

df.merge(df.groupby('cluster_id').apply(valid_row_dup).to_frame(), on='cluster_id')

    cluster_id memo      0
0            1    m  False
1            1    n  False
2            2    m  False
3            2    m  False
4            2    n  False
5            3    m   True
6            3    m   True
7            3    m   True
8            3    n   True
9            4    m   True
10           4    n   True
11           4    n   True
12           4    n   True

答案 1 :(得分:1)

我同意克里斯的回答。 只是想提供一个完善的解决方案。

df.merge(df.groupby('cluster_id').apply(valid_row_dup).\
    to_frame().reset_index().\
    rename(columns={0:'valid'}),
    on='cluster_id', how='inner')

答案 2 :(得分:1)

如果您通过其他方式定义函数:

def valid_row_dup2(ser):
    num_real_invs = ser[ser == 'm'].size        # Number of 'm'
    num_reversals_invs = ser[ser == 'n'].size   # Number of 'n'
    siz = ser.size                  # Total size
    diff = abs(num_real_invs - num_reversals_invs)
    if num_real_invs == siz:        # Only 'm'
        return True
    elif num_reversals_invs == siz: # Only 'n'
        return False
    elif diff > 0:          # Different number of 'm' and 'n'
        if diff % 2 == 0:   # Even diff
            return True
        elif diff == 1:     # Difference by one
            return False
        else:               # Odd diff, > 1
            return True
    else:                   # Equal number of 'm' and 'n'
        return False

您可以如下添加新列:

df['valid'] = df.groupby('cluster_id').memo.transform(valid_row_dup2)

恕我直言,这是一个更简单的解决方案(没有merge,您只需添加一个新列)。

相关问题