在pySpark中获取每组的前N个项目

时间:2016-07-28 02:21:26

标签: python apache-spark

我使用Spark 1.6.2,我有以下数据结构:

sample = sqlContext.createDataFrame([
                    (1,['potato','orange','orange']),
                    (1,['potato','orange','yogurt']),
                    (2,['vodka','beer','vodka']),
                    (2,['vodka','beer','juice', 'vinegar'])

    ],['cat','terms'])

我想提取每只猫最常用的前N项。我开发了以下似乎有效的解决方案,但是我想看看是否有更好的方法来做到这一点。

from collections import Counter
def get_top(it, terms=200):
    c = Counter(it.__iter__())
    return [x[0][1] for x in c.most_common(terms)]

( sample.select('cat',sf.explode('terms')).rdd.map(lambda x: (x.cat, x.col))
 .groupBy(lambda x: x[0])
 .map(lambda x: (x[0], get_top(x[1], 2)))
 .collect()
)

它提供以下输出:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])]

这符合我的要求,但我真的不喜欢我使用Counter的事实。我怎么能单独用火花来做呢?

由于

1 个答案:

答案 0 :(得分:2)

如果这样做,最好将其发布到Code Review

正如练习一样,我没有使用计数器,但很大程度上你只是在复制相同的功能。

  • 计算(termcat
  • 的每次出现次数
  • 分组2
  • 根据Count和slice将值排序为术语数(from operator import add (sample.select('cat', sf.explode('terms')) .rdd .map(lambda x: (x, 1)) .reduceByKey(add) .groupBy(lambda x: x[0][0]) .mapValues(lambda x: [r[1] for r, _ in sorted(x, key=lambda a: -a[1])[:2]]) .collect())

代码:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])]

输出:

triple2