对熊猫数据框进行分组并根据条件进行验证

时间:2018-11-14 08:18:11

标签: python-3.x pandas

数据框:

id   Base   field1    field2    field3
1     Y      AA         BB        CC
1     N      AA         BB        CC
1     N      AA         BB        CC     
2     Y      DD         EE        FF
2     N      OO         EE        WT
2     N      DD         JQ        FF
3     Y      MM         NN        TT
3     Y      MM         NN        TT 
3     N      MM         NN        TT

预期结果是根据ID列对该数据帧进行分组,应进行2次验证。

  1. 首先检查每个组中是否只有一个基本值“ Y”。如果仅是真的,则应将该行用作验证步骤2的参考,否则将错误写为“为ID找到多个基数Y”,并继续执行步骤1获取下一个ID

  2. 验证所有其他具有“ Base:N”的列上的数据是否与Base为“ Y”的列上的数据匹配,并在error列中写入不匹配的字段名称。产品栏是唯一字段,可以忽略以进行数据比较。

  3. 针对数据帧中的所有ID重复此操作。

预期结果是

id  product Base  field1  field2  field3   Error
1   A        Y     AA       BB      CC     Reference value
1   B        N     AA       BB      CC     Pass
1   C        N     AA       BB      CC     Pass
2   D        Y     DD       EE      FF     Reference value
2   E        N     OO       EE      WT     field1, field3 mismatch    
2   F        N     DE       JQ      FF     field1, field2 mismatch 
3   G        Y     MM       NN      TT     more than 1 Y found for id:
3   H        Y     MM       NN      TT     more than 1 Y found for id:
3   I        N     MM       NN      TT     more than 1 Y found for id:

对此有任何帮助吗?

1 个答案:

答案 0 :(得分:0)

使用自定义功能:

def f(x):
    #boolena mask for compare Y
    mask = x['Base'] == 'Y'
    #check multiple Y by sum of Trues
    if mask.sum() > 1:
        x['Error'] = 'more than 1 base Y found for id:{}'.format(x.name)
    else:
        #remove columns for not comparing with not equal
        cols = x.columns.difference(['Base','product'])
        mask1 = x[cols].ne(x.loc[mask, cols])
        #if difference get columns names by dot
        if mask1.values.any():
            vals = mask1.dot(mask1.columns + ', ').str.rstrip(', ') + ' mismatch with base' 
            x['Error'] = np.where(mask, 'Base: Y', vals)    
        else:
            x['Error'] = np.where(mask, 'Base: Y', 'Pass')    

    return x

df = df.groupby(level=0).apply(f)
print (df)
   product Base field1 field2 field3                              Error
id                                                                     
1        A    Y     AA     BB     CC                            Base: Y
1        B    N     AA     BB     CC                               Pass
1        C    N     AA     BB     CC                               Pass
2        D    Y     DD     EE     FF                            Base: Y
2        E    N     OO     EE     WT  field1, field3 mismatch with base
2        F    N     DD     JQ     FF          field2 mismatch with base
3        G    Y     MM     NN     TT  more than 1 base Y found for id:3
3        H    Y     MM     NN     TT  more than 1 base Y found for id:3
3        I    N     MM     NN     TT  more than 1 base Y found for id:3

示例数据框:

df = pd.DataFrame({'id': [1, 1, 1, 2, 2, 2, 3, 3, 3], 
                   'product': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'], 
                   'Base': ['Y', 'N', 'N', 'Y', 'N', 'N', 'Y', 'Y', 'N'], 
                   'field1': ['AA', 'AA', 'AA', 'DD', 'OO', 'DD', 'MM', 'MM', 'MM'], 
                   'field2': ['BB', 'BB', 'BB', 'EE', 'EE', 'JQ', 'NN', 'NN', 'NN'], 
                   'field3': ['CC', 'CC', 'CC', 'FF', 'WT', 'FF', 'TT', 'TT', 'TT']})
df = df.set_index('id')
print (df)
   product Base field1 field2 field3
id                                  
1        A    Y     AA     BB     CC
1        B    N     AA     BB     CC
1        C    N     AA     BB     CC
2        D    Y     DD     EE     FF
2        E    N     OO     EE     WT
2        F    N     DD     JQ     FF
3        G    Y     MM     NN     TT
3        H    Y     MM     NN     TT
3        I    N     MM     NN     TT