在python pandas中应用带有shift函数的lambda,需要替换一些null元素

时间:2015-01-19 15:11:04

标签: python pandas lambda

我正在尝试在数据框中执行以下操作。 如果Period不是1,则更改Column Attrition的值,然后通过groupby中上面的行中的attrition值将该行中保留列的值加倍。我的尝试如下:

import pandas as pd

data = {'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425', '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5,'' ,'' ,0.85 ,0.865,'' ,0.74 ,'','' ],
    'Retention': [0.95,0.85,0.94,0.85,0.97,0.93,0.97,0.93,0.94]}

df = pd.DataFrame(data, columns= ['Country', 'Week', 'Period', 'Attrition','Retention'])
print df

给我这个输出:

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2                 0.85
2      DE  201426       3                 0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2                 0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2                 0.93
8      FR  201426       3                 0.94

以下内容:

df['Attrition'] = df.groupby(['Country','Week']).apply(lambda x: x.Attrition.shift(1)*x['Retention'] if x.Period != 1 else x.Attrition)

print df

给了我这个错误:

df['Attrition'] = df.groupby(['Country','Week']).apply(lambda x: x.Attrition.shift(1)*x['Retention'] if x.Period != 1 else x.Attrition)

ValueError:具有多个元素的数组的真值是不明确的。使用a.any()或a.all()

更新:完成编译解决方案

下面是我之前的完整工作解决方案,其中基本上使用了Primer的答案,但添加了一个while循环以继续在数据帧列上运行Lambda函数,直到没有更多的NaN。

import pandas as pd
import numpy as np

data = {'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425', '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5, '' ,'' ,0.85 ,0.865,'' ,0.74 ,'','' ],
    'Retention': [0.95,0.85,0.94,0.85,0.97,0.93,0.97,0.93,0.94]}

df = pd.DataFrame(data, columns= ['Country', 'Week', 'Period', 'Attrition','Retention'])
print df

输出:启动DF

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2                 0.85
2      DE  201426       3                 0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2                 0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2                 0.93
8      FR  201426       3                 0.94

解决方案:

#Replaces empty string with NaNs
df['Attrition'] = df['Attrition'].replace('', np.nan)

#Stores a count of the number of null or NaNs in the column.
ContainsNaN = df['Attrition'].isnull().sum()

#run the loop while there are some NaNs in the column.
while ContainsNaN > 0:    
    df['Attrition'] = df.groupby(['Country','Week']).apply(lambda x: pd.Series(np.where((x.Period != 1), x.Attrition.shift() * x['Retention'], x.Attrition)))        
    ContainsNaN = df['Attrition'].isnull().sum()

print df

输出结果

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2     0.425       0.85
2      DE  201426       3    0.3995       0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2   0.80445       0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2    0.6882       0.93
8      FR  201426       3  0.646908       0.94

1 个答案:

答案 0 :(得分:2)

首先,您的Attrition列将数字数据与空字符串''混合在一起,这通常不是一个好主意,应该在尝试对此列进行计算之前修复:

df.loc[df['Attrition'] == '', 'Attrition'] = pd.np.nan
df['Attrition'] = df.Attrition.astype('float')

您得到的错误是因为.applyx.Period != 1中的条件产生了一个布尔数组:

0    False
1     True
2     True
3    False
4    False
5     True
6    False
7     True
8     True
Name: Period, dtype: bool

哪个.apply不知道如何处理,因为它的含糊不清(即在这种情况下应该是什么?)。

您可以考虑numpy.where执行此任务:

import numpy as np
g = df.groupby(['Country','Week'], as_index=0, group_keys=0)
df['Attrition'] = g.apply(lambda x: pd.Series(np.where((x.Period != 1), x.Attrition.shift() * x['Retention'], x.Attrition)).fillna(method='ffill')).values
df

得到以下特性:

  Country    Week  Period  Attrition  Retention
0      DE  201426       1      0.500       0.95
1      DE  201426       2      0.425       0.85
2      DE  201426       3      0.425       0.94
3      US  201426       1      0.740       0.85
4      US  201425       1      0.688       0.97
5      US  201425       2      0.688       0.93
6      FR  201426       1      0.865       0.97
7      FR  201426       2      0.804       0.93
8      FR  201426       3      0.850       0.94

请注意,我添加了.fillna方法,该方法将NaN填入上次观察到的值。