Pyspark比较所有键的可迭代列表并返回相同元素的计数

时间:2016-11-27 20:34:56

标签: python apache-spark pyspark

我正在使用Pyspark处理具有键值对的数据集,如下所示:

[(u'1', u'10'), (u'1', u'15'), (u'1', u'5'), (u'2', u'11'), (u'2', u'15'), (u'2', u'30'),  (u'3', u'10'), (u'3', u'5'), (u'3', u'11')]

数据可以解释为

1 => 10, 15, 5;  2 => 11, 15, 30;  3 => 10, 5, 11;

我正在尝试比较所有键的值,并找到相同值的计数。在样本数据中,键1和键2都有值15,输出应返回1;键1和3都有值10和5,输出应该是2 ......等等。

预期产出:

1, 2 => 1;  1, 3 => 2;  2, 3 => 1;

我的想法是按键对数据进行分组以获取可迭代列表并比较列表中的各个元素。

data1 = data.groupByKey()

data1.map(lambda x: (x[0], list(x[1]))).collect()

data1的输出:

[(u'1', [u'10', u'15', u'15']), (u'2', [u'11', u'15', u'30']), (u'3', [u'10', u'5', u'11'])]

我无法想出一种迭代列表并比较所有键的元素的方法。如果有人知道一种方法来进行迭代或对如何解决问题提出建议,我将不胜感激。提前谢谢!

1 个答案:

答案 0 :(得分:0)

我可以通过键循环获得结果:

from pyspark.sql import SparkSession

spark = SparkSession\
    .builder\
    .appName("common values")\
    .getOrCreate()
sc = spark.sparkContext

data = [(u'1', u'10'), (u'1', u'15'), (u'1', u'5'), (u'2', u'11'), (u'2', u'15'), (u'2', u'30'),  (u'3', u'10'), (u'3', u'5'), (u'3', u'11')]
rdd = sc.parallelize(data)
rdd = rdd.groupByKey() \
           .mapValues(list)
print "data:\n", rdd.collect()
keys = rdd.map(lambda l: l[0]).collect()
print "keys:", keys
result = sc.emptyRDD()
for k in keys:
    values = rdd.filter(lambda l: l[0] == k)\
                .map(lambda l: l[1])
    values = set(values.collect()[0])
    rdd = rdd.filter(lambda l: l[0] != k).persist()
    # only compare unique pairs
    common = rdd.map(lambda l: [(l[0], k), len(set(l[1]).intersection(values))])
    result = result.union(common)
print "results:", result.collect()

构造一组键,并为每个键:

  • 构建一组相应的值
  • 为每个键("使用的"键除外)检查多少个元素 是共同的

它适用于这个微小的测试数据集(spark 2.0.2,python 2.7),不确定它有多好 规模。

results: [[(u'3', u'1'), 2], [(u'2', u'1'), 1], [(u'2', u'3'), 1]]