I have a couple of dataframes like these:
rdd_1 = sc.parallelize([(0,10,"A",2), (1,20,"B",1), (2,30,"A",2)])
rdd_2 = sc.parallelize([(0, 10, 223, "201603"), (0, 10, 83, "201602"),
                        (1, 20, 3003, "201601"), (1, 20, 3213, "201602"), (1, 20, 9872, "201603"),
                        (2, 30, 10, "201602"), (2, 30, 62, "201601"), (2, 40, 2321, "201601")])
df_tg = sqlContext.createDataFrame(rdd_1, ["id", "type", "route_a", "route_b"])
df_data = sqlContext.createDataFrame(rdd_2, ["id", "type", "cost", "date"])
df_tg.show()
+---+----+-------+-------+
| id|type|route_a|route_b|
+---+----+-------+-------+
|  0|  10|      A|      2|
|  1|  20|      B|      1|
|  2|  30|      A|      2|
+---+----+-------+-------+
df_data.show()
+---+----+----+------+
| id|type|cost| date|
+---+----+----+------+
|  0|  10| 223|201603|
|  0|  10|  83|201602|
|  1|  20|3003|201601|
|  1|  20|3213|201602|
|  1|  20|9872|201603|
|  2|  30|  10|201602|
|  2|  30|  62|201601|
|  2|  40|2321|201601|
+---+----+----+------+
So I need to add columns like this:
+---+----+-------+-------+-----------+-----------+-----------+
| id|type|route_a|route_b|cost_201603|cost_201602|cost_201601|
+---+----+-------+-------+-----------+-----------+-----------+
|  0|  10|      A|      2|        223|         83|       None|
|  1|  20|      B|      1|       9872|       3213|       3003|
|  2|  30|      A|      2|       None|         10|         62|
+---+----+-------+-------+-----------+-----------+-----------+
To get there I have to do several joins:
df_tg = df_tg.join(df_data[df_data.date == "201603"], ["id", "type"])
and I have to rename the columns so they don't get overwritten:
df_tg = df_tg.join(df_data[df_data.date == "201603"], ["id", "type"]).withColumnRenamed("cost","cost_201603")
I can write a function that does this, but I have to loop over the available dates and columns, which generates a lot of joins, each with a full table scan:
def feature_add(df_target, df_feat, feat_cols, period):
    for ref_month in period:
        # keep only this month's rows so the join stays one row per (id, type)
        df_month = df_feat[df_feat.date == ref_month]
        for feat_col in feat_cols:
            df_target = df_target.join(df_month, ["id", "type"]).select(
                *[df_target[column] for column in df_target.columns] + [df_month[feat_col]]
            ).withColumnRenamed(feat_col, feat_col + '_' + ref_month)
    return df_target
df_tg = feature_add(df_tg, df_data, ["cost"], ["201602", "201603", "201601"])
This works, but it's ugly. How can I add these columns more cleanly, including when I call the same function for other dataframes? Note that the frames don't line up completely, so I need an inner join.
Answer 0 (score: 3)
I would suggest using the pivot function, like this:
from pyspark.sql.functions import *
rdd_1 = sc.parallelize([(0,10,"A",2), (1,20,"B",1), (2,30,"A",2)])
rdd_2 = sc.parallelize([(0, 10, 223, "201603"), (0, 10, 83, "201602"),
                        (1, 20, 3003, "201601"), (1, 20, 3213, "201602"), (1, 20, 9872, "201603"),
                        (2, 30, 10, "201602"), (2, 30, 62, "201601"), (2, 40, 2321, "201601")])
df_tg = sqlContext.createDataFrame(rdd_1, ["id", "type", "route_a", "route_b"])
df_data = sqlContext.createDataFrame(rdd_2, ["id", "type", "cost", "date"])
pivot_df_data = df_data.groupBy("id","type").pivot("date").agg({"cost" : "sum"})
pivot_df_data.join(df_tg, ['id','type'], 'inner').select('id','type','route_a','route_b','201601','201602','201603').show()
# +---+----+-------+-------+------+------+------+
# | id|type|route_a|route_b|201601|201602|201603|
# +---+----+-------+-------+------+------+------+
# |  0|  10|      A|      2|  null|    83|   223|
# |  1|  20|      B|      1|  3003|  3213|  9872|
# |  2|  30|      A|      2|    62|    10|  null|
# +---+----+-------+-------+------+------+------+
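If you also want the cost_ prefix from the desired output and the months are known in advance, here is a minimal sketch (assuming a fixed month list; passing the values to pivot also spares Spark the extra job it otherwise runs to collect the distinct dates):
# assumed: the reporting months are known up front
months = ["201601", "201602", "201603"]
# passing the values explicitly skips the pass that infers distinct pivot values
pivot_df_data = df_data.groupBy("id", "type").pivot("date", months).sum("cost")
# prefix the pivoted columns so they read cost_<month>, as in the question
for m in months:
    pivot_df_data = pivot_df_data.withColumnRenamed(m, "cost_" + m)
df_tg.join(pivot_df_data, ["id", "type"], "inner").show()
This should give the cost_201601/cost_201602/cost_201603 layout directly, with null wherever an (id, type) pair has no row for that month, and the same join works unchanged for any other target dataframe.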