目前,我已经写出了能够实现预期行为的可运行代码,但它依赖于对 DataFrame 的嵌套循环,在大型数据集上性能非常差(在 1 亿行的 DataFrame 上耗时超过 1 小时)。是否有更高效的方法(例如使用 `map` 或某种 `for` 结构)来完成这些循环、联接和嵌套循环?




# Sample rows to be enriched. `location` and/or `quarter` may be missing
# (None), which is why a multi-level fallback lookup is needed further down.
df = spark.createDataFrame(
  data=[
    (1, 12, 'East', 'Q1'),
    (2, 14, 'East', 'Q1'),
    (3, 12, 'West', 'Q2'),
    (4, 13, 'West', 'Q2'),
    (5, 13, 'East', 'Q1'),
    (6, 12, None, None),
    (7, 13, 'West', None),
    (8, 12, 'West', None),
  ],
  schema=['id', 'product', 'location', 'quarter'],
)

# Reference table holding the known `cost` / `quantity` values for each
# (product, location, quarter) combination.
map_df = spark.createDataFrame(
  data=[
    (12, 'East', 'Q1', 10, 15),
    (13, 'East', 'Q1', 5, 10),
    (14, 'East', 'Q1', 20, 20),
    (13, 'West', 'Q1', 7, 8),
    (14, 'West', 'Q1', 10, 12),
    (12, 'East', 'Q2', 30, 5),
  ],
  schema=['product', 'location', 'quarter', 'cost', 'quantity'],
)

# Fallback hierarchy, numbered from 1: each entry is the list of columns that
# `map_df` is grouped and joined on at that level. The levels are visited in
# this order, from the most specific key set down to `product` alone.
group_dict = dict(enumerate(
  (
    ['product', 'location', 'quarter'],
    ['product', 'location'],
    ['product', 'quarter'],
    ['product'],
  ),
  start=1,
))

# Dimension columns carried through the lookup loop.
field_list = ['product', 'location', 'quarter']
# Measure columns that are averaged from `map_df` and filled in.
facts = ['cost', 'quantity']
# Name of the surrogate row-key column used for the final join.
h = 'hash'

# Tag every row with a surrogate key so that only the key, the dimension
# columns and the facts have to be carried through the loop; the remaining
# columns are re-attached by the final join.
# NOTE(review): `hash`, `concat_ws`, `avg`, `col` and `when` are assumed to be
# the pyspark.sql.functions versions (the imports are not visible here).
# NOTE(review): the key is only as unique as the concatenated row: `concat_ws`
# skips NULLs and hashes can collide, so fully duplicated rows (or rows that
# differ only in which column is NULL) would share a key and fan out in the
# final join. Here `id` is part of the row, which keeps the sample rows apart.
origin_df = df.withColumn(h, hash(concat_ws('|', *df.columns)))
loop_df = origin_df.select(h, *field_list)

# The aggregation expressions do not depend on the grouping level, so build
# them once instead of on every iteration.
exp = [avg(f).alias(f"{f}_coalesce") for f in facts]

## Walk the fallback levels in order. At each level, average every fact of
## `map_df` over that level's columns, left-join it in, and use it only for
## rows whose fact is still NULL.
for c in group_dict.values():
  loop_df = loop_df.join(map_df.groupBy(c).agg(*exp), c, 'left')
  for f in facts:
    f_co = f"{f}_coalesce"
    if f not in loop_df.columns:
      # First level: the fact column does not exist yet, so seed it.
      loop_df = loop_df.withColumn(f, col(f_co))
    else:
      # Later levels: keep an already-resolved value, otherwise fall back.
      loop_df = loop_df.withColumn(f, when(col(f).isNull(), col(f_co)).otherwise(col(f)))
    # Always remove the helper column so the next level's join does not
    # produce a duplicate `<fact>_coalesce` column.
    loop_df = loop_df.drop(f_co)

# Re-attach the resolved facts to the original rows and discard the key.
return_df = loop_df.select(h, *facts).join(origin_df.drop(*facts), h).drop(h)


import pandas as pd

# Sample rows to be enriched; `location` / `quarter` may be missing.
df = pd.DataFrame([
    (1, 12, 'East', 'Q1'), (2, 14, 'East', 'Q1'), (3, 12, 'West', 'Q2'),
    (4, 13, 'West', 'Q2'), (5, 13, 'East', 'Q1'), (6, 12,  None,  None),
    (7, 13, 'West', None), (8, 12, 'West', None)],
    columns=['id', 'product', 'location', 'quarter'])

# Reference table with the known cost / quantity per combination.
map_df = pd.DataFrame([
    (12, 'East', 'Q1', 10, 15), (13, 'East', 'Q1', 5,  10), (14, 'East', 'Q1', 20, 20),
    (13, 'West', 'Q1', 7,  8),  (14, 'West', 'Q1', 10, 12), (12, 'East', 'Q2', 30, 5)],
    columns=['product', 'location', 'quarter', 'cost', 'quantity'])

# Matching rules, keyed by priority: a lower number is a more specific match.
group_dict = {
    1: ['product', 'location', 'quarter'], 2: ['product', 'location'],
    3: ['product', 'quarter'],             4: ['product']}

# Measure columns looked up from `map_df`.
facts = ['cost', 'quantity']

# One merged frame per rule, each tagged with the rule number that made it.
ts = list()
for rule, fields in group_dict.items():
    # Take only the join keys and the facts from `map_df`. Merging the full
    # table would suffix the overlapping `location` / `quarter` columns
    # (`_x` / `_y`) and lose the values that belong to `df`.
    t = pd.merge(left=df, right=map_df[fields + facts], how='outer', on=fields)
    t['rule'] = rule
    ts.append(t)

# concatenate data frames, and drop rows with NaN values
# (the outer merge leaves NaN ids for unmatched `map_df` rows and NaN facts
# for unmatched `df` rows; neither is a usable match)
new_df = (pd.concat(ts, ignore_index=True)
          .loc[:, ['id', 'product', 'location', 'quarter',
                   'cost', 'quantity', 'rule']]
          .loc[lambda x: x['id'].notna()]
          .loc[lambda x: x['cost'].notna()]
          .loc[lambda x: x['quantity'].notna()])

# find the first rule that has id, cost and quantity
first_match = new_df.groupby('id')['rule'].transform('min')

# keep first (i.e., best) match, then drop the helper column
new_df = new_df[new_df['rule'] == first_match].drop(columns=['rule'])

# fields to keep
fields = ['id', 'product', 'location', 'quarter']

# use average cost, average quantity (some rows have more than one match);
# dropna=False keeps the rows whose location / quarter are missing
new_df = new_df.groupby(fields, as_index=False, dropna=False)[facts].mean()


    id  product location quarter  cost  quantity
0  1.0       12     East      Q1  10.0      15.0
1  2.0       14     East      Q1  20.0      20.0
2  3.0       12     West      Q2  30.0       5.0
3  4.0       13     West      Q2   7.0       8.0
4  5.0       13     East      Q1   5.0      10.0
5  6.0       12      NaN     NaN  20.0      10.0
6  7.0       13     West     NaN   7.0       8.0
7  8.0       12     West     NaN  20.0      10.0