我正在使用Pyspark处理具有键值对的数据集,如下所示:
[(u'1', u'10'), (u'1', u'15'), (u'1', u'5'), (u'2', u'11'), (u'2', u'15'), (u'2', u'30'), (u'3', u'10'), (u'3', u'5'), (u'3', u'11')]
数据可以解释为
1 => 10, 15, 5; 2 => 11, 15, 30; 3 => 10, 5, 11;
我正在尝试比较所有键的值,并找到相同值的计数。在样本数据中,键1和键2都有值15,输出应返回1;键1和3都有值10和5,输出应该是2 ......等等。
预期产出:
1, 2 => 1; 1, 3 => 2; 2, 3 => 1;
我的想法是按键对数据进行分组以获取可迭代列表并比较列表中的各个元素。
data1 = data.groupByKey()
data1.map(lambda x: (x[0], list(x[1]))).collect()
data1的输出:
[(u'1', [u'10', u'15', u'15']), (u'2', [u'11', u'15', u'30']), (u'3', [u'10', u'5', u'11'])]
我无法想出一种迭代列表并比较所有键的元素的方法。如果有人知道一种方法来进行迭代或对如何解决问题提出建议,我将不胜感激。提前谢谢!
答案 0 :(得分:0)
我可以通过键循环获得结果:
from pyspark.sql import SparkSession
spark = SparkSession\
.builder\
.appName("common values")\
.getOrCreate()
sc = spark.sparkContext
data = [(u'1', u'10'), (u'1', u'15'), (u'1', u'5'), (u'2', u'11'), (u'2', u'15'), (u'2', u'30'), (u'3', u'10'), (u'3', u'5'), (u'3', u'11')]
rdd = sc.parallelize(data)
rdd = rdd.groupByKey() \
.mapValues(list)
print "data:\n", rdd.collect()
keys = rdd.map(lambda l: l[0]).collect()
print "keys:", keys
result = sc.emptyRDD()
for k in keys:
values = rdd.filter(lambda l: l[0] == k)\
.map(lambda l: l[1])
values = set(values.collect()[0])
rdd = rdd.filter(lambda l: l[0] != k).persist()
# only compare unique pairs
common = rdd.map(lambda l: [(l[0], k), len(set(l[1]).intersection(values))])
result = result.union(common)
print "results:", result.collect()
构造一组键,并为每个键:
它适用于这个微小的测试数据集(spark 2.0.2,python 2.7),不确定它有多好 规模。
results: [[(u'3', u'1'), 2], [(u'2', u'1'), 1], [(u'2', u'3'), 1]]