Pyspark数据帧:计算列中的唯一值,与其他列中的值独立地共同计算

时间:2018-02-22 18:08:29

标签: python python-2.7 pyspark spark-dataframe

我有一个火花数据框,包含数十亿个从各种来源获得的两种类型分子,调节因子和目标之间相互作用的预测(这些之间没有重叠)。我需要添加一列 包含数字资源,用于预测给定“监管者”的至少一个交互,并给予“目标”。

换句话说,对于每对'Regulator'和'Target',我试图获得包含'Regulator'和'Target'值的Source的数量,即使在一次交互中没有配对。

示例:

+---------+------+------+
|Regulator|Target|Source|
+---------+------+------+
|        m|     A|     x|
|        m|     B|     x|
|        m|     C|     z|
|        n|     A|     y|
|        n|     C|     x|
|        n|     C|     z|
+---------+------+------+

我想要获得的是:

+---------+------+------+----------+
|Regulator|Target|Source|No.sources|
+---------+------+------+----------+
|        m|     A|     x|         1|
|        m|     B|     x|         1|
|        m|     C|     z|         2|
|        n|     A|     y|         2|
|        n|     C|     x|         2|
|        n|     C|     z|         2|
+---------+------+------+----------+

进一步解释:

第一行(m, A, x)

  • 涉及m的相互作用由来源x和z预测。
  • 涉及A的相互作用由来源x和y预测。
  • 这些的重叠是x,因此No.sources等于1.

第二行(m, B, x)

  • 涉及m的相互作用由来源x和z预测。
  • 涉及B的交互仅由来源x预测。
  • 这些的重叠是x,因此No.sources等于1.

第三行(m, C, z)

  • 涉及m的相互作用由x和z
  • 预测
  • 涉及C的相互作用由来源x和z预测。
  • 这些的重叠是x,z,因此No.sources等于2.

1 个答案:

答案 0 :(得分:0)

以下是解决此问题的一种方法。对于每一行,创建2个新列:

  • 专栏'RS''Regulator'
  • 的来源集
  • 专栏'TS''Target'
  • 的来源集

然后你想要的输出是这些集合的交集长度。

考虑以下示例:

创建DataFrame

from pyspark.sql Window
import pyspark.sql.functions as f
cols = ["Regulator", "Target", "Source"]
data = [
    ('m', 'A', 'x'),
    ('m', 'B', 'x'),
    ('m', 'C', 'z'),
    ('n', 'A', 'y'),
    ('n', 'C', 'x'),
    ('n', 'C', 'z')
]

df = sqlCtx.createDataFrame(data, cols)

创建新列

使用pyspark.sql.functions.collect_set()pyspark.sql.Window计算'Source'列的不同值:

df = df.withColumn(
    'RS',
    f.collect_set(f.col('Source')).over(Window.partitionBy('Regulator'))
)

df = df.withColumn(
    'TS',
    f.collect_set(f.col('Source')).over(Window.partitionBy('Target'))
)
df.sort('Regulator', 'Target', 'Source').show()
#+---------+------+------+------+---------+
#|Regulator|Target|Source|    TS|       RS|
#+---------+------+------+------+---------+
#|        m|     A|     x|[y, x]|   [z, x]|
#|        m|     B|     x|   [x]|   [z, x]|
#|        m|     C|     z|[z, x]|   [z, x]|
#|        n|     A|     y|[y, x]|[y, z, x]|
#|        n|     C|     x|[z, x]|[y, z, x]|
#|        n|     C|     z|[z, x]|[y, z, x]|
#+---------+------+------+------+---------+

计算交叉点的长度

定义udf以返回两个集合的交集长度,并使用它来计算'No_sources'列。 (注意,我在列名中使用了_而不是.,因为它可以更轻松地使用select()。)

intersection_length_udf = f.udf(lambda u, v: len(set(u) & set(v)), IntegerType())

df = df.withColumn('No_sources', intersection_length_udf(f.col('TS'), f.col('RS')))

df.select('Regulator', 'Target', 'Source', 'No_sources')\
    .sort('Regulator', 'Target', 'Source')\
    .show()
#+---------+------+------+----------+
#|Regulator|Target|Source|No_sources|
#+---------+------+------+----------+
#|        m|     A|     x|         1|
#|        m|     B|     x|         1|
#|        m|     C|     z|         2|
#|        n|     A|     y|         2|
#|        n|     C|     x|         2|
#|        n|     C|     z|         2|
#+---------+------+------+----------+