下面是月度患者活动的数据框
# Sample input: two patients (PTNT_GID), each row holding one 0/1 activity
# flag per month column -- 0 for the first 18 month columns, 1 afterwards.
# Assumes `sc` is an existing SparkContext (it is not created in this snippet).
rdd = sc.parallelize([("00000000000087052962",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1),
("00000000000087052963",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1)]
)
# Column 0 is the patient id; the month columns follow in chronological order
# from m_200402 to m_201703.
# NOTE(review): the month columns are not contiguous calendar months (e.g.
# m_200402 is followed by m_200407), so "N columns back" is not the same as
# "N calendar months back" -- confirm this is intended.
df=rdd.toDF(['PTNT_GID','m_200402','m_200407','m_200408','m_200501','m_200503','m_200504','m_200505','m_200506','m_200508','m_200509','m_200512','m_200604','m_200605','m_200607','m_200608','m_200609','m_200611','m_200612','m_200701','m_200703','m_200705','m_200708','m_200709','m_200710','m_200711','m_200712','m_200801','m_200803','m_200804','m_200805','m_200806','m_200807','m_200808','m_200809','m_200810','m_200811','m_200812','m_200901','m_200902','m_200903','m_200904','m_200905','m_200906','m_200907','m_200908','m_200909','m_200910','m_200911','m_200912','m_201001','m_201002','m_201003','m_201004','m_201005','m_201006','m_201007','m_201008','m_201009','m_201010','m_201011','m_201012','m_201101','m_201102','m_201103','m_201104','m_201105','m_201106','m_201107','m_201108','m_201109','m_201110','m_201111','m_201112','m_201201','m_201203','m_201204','m_201205','m_201206','m_201207','m_201208','m_201209','m_201210','m_201211','m_201212','m_201301','m_201302','m_201303','m_201304','m_201305','m_201306','m_201307','m_201308','m_201309','m_201310','m_201311','m_201312','m_201401','m_201402','m_201403','m_201404','m_201405','m_201406','m_201407','m_201408','m_201409','m_201410','m_201411','m_201412','m_201501','m_201502','m_201503','m_201504','m_201505','m_201506','m_201507','m_201508','m_201509','m_201510','m_201511','m_201512','m_201601','m_201602','m_201603','m_201604','m_201605','m_201606','m_201607','m_201608','m_201609','m_201610','m_201611','m_201612','m_201701','m_201702','m_201703'])
用例：我希望跟踪患者在过去36个月内的活动。将过去36个月划分为6个连续的6个月区间；如果患者在每个区间内都至少有一次活动（数据框中的标志为1表示有活动，0表示无活动），则将该月的激活标志设为1，否则设为0。
# I wrote the logic below to update the DataFrame. The first month is
# m_200402 and the last is m_201703. The first 36 months of each patient do
# not need to be checked; tracking starts from the 37th month.
def chunkify(alist, wanted_parts):
    """Split *alist* into *wanted_parts* contiguous slices.

    Slice ``part`` runs from ``part * len(alist) // wanted_parts`` up to
    ``(part + 1) * len(alist) // wanted_parts``, so slice lengths differ by
    at most one.  Empty slices appear when ``wanted_parts`` exceeds the
    length of *alist*; a non-positive ``wanted_parts`` yields ``[]``.
    """
    size = len(alist)
    pieces = []
    for part in range(wanted_parts):
        start = part * size // wanted_parts
        stop = (part + 1) * size // wanted_parts
        pieces.append(alist[start:stop])
    return pieces
# Driver-side version of the activation rule.  toLocalIterator() streams every
# row to the driver, so this is only practical for small data; for large data
# use a Spark window function instead.
WINDOW_MONTHS = 36  # look-back length, in month columns
N_CHUNKS = 6        # the look-back is split into this many equal chunks

result = []
for row in df.rdd.map(list).toLocalIterator():
    # `row` keeps the raw activity flags and is only read; the activation
    # flags go into a separate copy so later windows never see values that
    # were written for earlier months.
    flags = list(row)
    # Index 0 is PTNT_GID, so month k (1-based) lives at index k.  The first
    # month with a full 36-month look-back is index WINDOW_MONTHS + 1, whose
    # window row[1:WINDOW_MONTHS + 1] covers months 1..36.  Earlier months are
    # left unchanged, as in the original loop.
    for j in range(WINDOW_MONTHS + 1, len(row)):
        chunks = chunkify(row[j - WINDOW_MONTHS:j], N_CHUNKS)
        # Active only if EVERY 6-month chunk contains at least one 1.
        flags[j] = 1 if all(1 in chunk for chunk in chunks) else 0
    result.append(flags)

# Prepend the header row so `result` mirrors the DataFrame layout.
result = [df.columns] + result
我想在 PySpark 中使用 lambda 函数，直接在数据框上实现上述逻辑。
答案 0（得分：1）
您应该先展开数据框，使每个 `PTNT_GID, month` 组合各占一行，然后再应用窗口函数。
import pyspark.sql.functions as psf
from itertools import chain
# Reshape wide -> long: one output row per (PTNT_GID, month column).
# create_map receives alternating lit(name) / col(name) arguments, building a
# {month_name: flag} map per row; posexplode then emits one row per map entry
# with columns `pos` (0-based position, following df.columns order), `key`
# (the month column name) and `value` (the flag).
df_expl = df.select(
    'PTNT_GID',
    psf.posexplode(
        psf.create_map(
            list(chain.from_iterable(
                (psf.lit(name), psf.col(name))
                for name in df.columns
                if name != 'PTNT_GID'
            ))
        )
    ),
)
+--------------------+---+--------+-----+
| PTNT_GID|pos| key|value|
+--------------------+---+--------+-----+
|00000000000087052962| 0|m_200402| 0|
|00000000000087052962| 1|m_200407| 0|
|00000000000087052962| 2|m_200408| 0|
|00000000000087052962| 3|m_200501| 0|
|00000000000087052962| 4|m_200503| 0|
|00000000000087052962| 5|m_200504| 0|
|00000000000087052962| 6|m_200505| 0|
|00000000000087052962| 7|m_200506| 0|
|00000000000087052962| 8|m_200508| 0|
|00000000000087052962| 9|m_200509| 0|
|00000000000087052962| 10|m_200512| 0|
|00000000000087052962| 11|m_200604| 0|
|00000000000087052962| 12|m_200605| 0|
|00000000000087052962| 13|m_200607| 0|
|00000000000087052962| 14|m_200608| 0|
|00000000000087052962| 15|m_200609| 0|
|00000000000087052962| 16|m_200611| 0|
|00000000000087052962| 17|m_200612| 0|
|00000000000087052962| 18|m_200701| 1|
|00000000000087052962| 19|m_200703| 1|
+--------------------+---+--------+-----+
现在可以应用窗口函数了。如果我理解正确，你是把过去36个月划分为6个长度为6个月的块；当且仅当这6个块中的每一块都至少包含一个1时，最终值才为1。这等价于：先对每个6个月的块取最大值，再对这6个最大值取最小值。
from pyspark.sql import Window
# Activation rule: split the 36 rows preceding the current row into six
# consecutive, non-overlapping 6-row chunks.  The flag is 1 only if every
# chunk contains at least one active month: the minimum over chunks of the
# maximum within each chunk.
# NOTE(review): frames count rows (recorded month columns), and the source
# columns skip some calendar months (m_200402 -> m_200407), so 36 rows are not
# necessarily 36 calendar months -- confirm this is acceptable.
N_CHUNKS = 6    # number of chunks in the look-back
CHUNK_ROWS = 6  # rows per chunk
LOOKBACK_ROWS = N_CHUNKS * CHUNK_ROWS

w = Window.partitionBy('PTNT_GID').orderBy('pos')
chunk_maxes = [
    # Chunk i spans rows [-(i+1)*6, -i*6 - 1] relative to the current row
    # (i=0 -> [-6, -1], ..., i=5 -> [-36, -31]).  Each frame is exactly
    # CHUNK_ROWS wide, so neighbouring chunks never share a row.
    psf.max('value').over(
        w.rowsBetween(-(i + 1) * CHUNK_ROWS, -i * CHUNK_ROWS - 1)
    )
    for i in range(N_CHUNKS)
]
res = df_expl.select(
    "*",
    # Only rows with a full look-back are scored (pos is 0-based, so
    # pos >= 36 means 36 preceding rows exist).  least() skips nulls, so
    # without this guard a row with a short history could be scored from
    # its partial chunks alone.  Rows without a full history get 0.
    psf.when(psf.col('pos') >= LOOKBACK_ROWS, psf.least(*chunk_maxes))
    .otherwise(0)
    .alias("act_6m")
).na.fill(0)
+--------------------+---+--------+-----+------+
| PTNT_GID|pos| key|value|act_6m|
+--------------------+---+--------+-----+------+
|00000000000087052962| 0|m_200402| 0| 0|
|00000000000087052962| 1|m_200407| 0| 0|
|00000000000087052962| 2|m_200408| 0| 0|
|00000000000087052962| 3|m_200501| 0| 0|
|00000000000087052962| 4|m_200503| 0| 0|
|00000000000087052962| 5|m_200504| 0| 0|
|00000000000087052962| 6|m_200505| 0| 0|
|00000000000087052962| 7|m_200506| 0| 0|
|00000000000087052962| 8|m_200508| 0| 0|
|00000000000087052962| 9|m_200509| 0| 0|
|00000000000087052962| 10|m_200512| 0| 0|
|00000000000087052962| 11|m_200604| 0| 0|
|00000000000087052962| 12|m_200605| 0| 0|
|00000000000087052962| 13|m_200607| 0| 0|
|00000000000087052962| 14|m_200608| 0| 0|
|00000000000087052962| 15|m_200609| 0| 0|
|00000000000087052962| 16|m_200611| 0| 0|
|00000000000087052962| 17|m_200612| 0| 0|
|00000000000087052962| 18|m_200701| 1| 0|
|00000000000087052962| 19|m_200703| 1| 0|
+--------------------+---+--------+-----+------+