将具有至少两次相同描述符的序列以相同顺序分组

时间:2018-07-16 16:09:49

标签: pyspark

我有以下数据框:

+--------+--------------------+
|      id|         description|
+--------+--------------------+
|14144206|(1.0, 0.0, 0.0, 0.0)|
|14144206|(0.0, 1.0, 0.0, 0.0)|
|19461601|(0.0, 0.0, 1.0, 0.0)|
|19461601|(0.0, 0.0, 0.0, 1.0)|
|34578543|(1.0, 0.0, 0.0, 0.0)|
|34578543|(0.0, 1.0, 0.0, 0.0)|
|45672467|(0.0, 1.0, 0.0, 0.0)|
|45672467|(0.0, 0.0, 1.0, 0.0)|
|45672467|(0.0, 0.0, 0.0, 1.0)|
+--------+--------------------+

可以通过以下代码获得:

df = sqlCtx.createDataFrame(
    [
        (14144206, '(1.0, 0.0, 0.0, 0.0)'),
        (14144206, '(0.0, 1.0, 0.0, 0.0)'),
        (19461601, '(0.0, 0.0, 1.0, 0.0)'),
        (19461601, '(0.0, 0.0, 0.0, 1.0)'),
        (34578543, '(1.0, 0.0, 0.0, 0.0)'),
        (34578543, '(0.0, 1.0, 0.0, 0.0)'),
        (45672467, '(0.0, 1.0, 0.0, 0.0)'),
        (45672467, '(0.0, 0.0, 1.0, 0.0)'),
        (45672467, '(0.0, 0.0, 0.0, 1.0)')
            ],
    ('id', 'description')
)

所需的输出是一个元组列表(长度可以为2、3,最大为length(描述)),其中每个元组都包含在其中出现2次,3次或4次等的id。列说明中的顺序相同。所以这里的输出应该是:

[(14144206, 34578543), (34578543, 45672467)]

第一步是对ID进行分组并加总描述,以获取以下数据帧:

+--------+--------------------+
|      id|         description|
+--------+--------------------+
|14144206|(1.0, 1.0, 0.0, 0.0)|
|19461601|(0.0, 0.0, 1.0, 1.0)|
|34578543|(1.0, 1.0, 0.0, 0.0)|
|45672467|(0.0, 1.0, 1.0, 1.0)|
+--------+--------------------+

但是,执行此指令时遇到以下错误:

df.groupBy("id").agg(sum("description").alias("sum_description"))

Error :"cannot resolve 'sum(`description`)' due to data type mismatch: function sum requires numeric types, not org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7

然后执行分组任务,我看不到是否有用于此目的的现有功能。

有关信息,我的真实数据框的列说明长度为35000,并且大约有15000个不同的ID

1 个答案:

答案 0 :(得分:0)

不确定我是否遵循逻辑,因为我得到的答案略有不同,但这是这样的:

# Find the cardinality of each description
description_cnt=df.groupby('description').agg(f.sum(f.lit(1)).alias('id_cnt'))
df=df.join(description_cnt, on='description')

# Group by description and count and gather the Id's into a list
df_id_grp=df.groupby('description','id_cnt')
             .agg(f.collect_list('id')
             .alias('grouped_id'))

# Filter down to count 2 and display
df_id_grp.filter(df_id_grp['id_cnt']==2)
         .select('grouped_id')
         .rdd.map(lambda x: tuple(x['grouped_id']))
         .take(20)

如您所见,答案与您的答案有所不同,但我认为逻辑与您所​​描述的相同。如果您可以审查并让我知道您的想法,我将作相应修改。