A more efficient way to filter a Spark RDD by a summarized list

Time: 2017-03-16 19:08:49

Tags: python r apache-spark pyspark

I have an RDD of data that I am doing some preprocessing/munging on before sending it through a Spark pipe() to R code for some complex hierarchical modeling. I have something that works, but it feels clunky and inefficient. I would love to get two things out of this question, since I am still learning how to work with Spark data structures in Python after working in R for a long time.

  1. Is there a more efficient way to carry out specific parts of this code?
  2. Would it be more efficient to do all of this as a SparkDataFrame rather than with RDDs?

Code

    I start by pulling in my text file and piping it through an R script that appends a stixel_id, which I will later use as the key in groupByKey(). My goal is to sum another column (obs) by stixel_id, then filter the whole RDD by a subset of stixel_ids.

    First, pull in the data, pipe it through R to append the stixel_id, and munge it a bit (mostly to deal with legacy formatting).

    # input format
    # u'83278,"train",-76.5018492,42.4622906,2012.8606557377,2012,315,9.25,2,0.805,2,0,112,10,28.2051,17.9487,0,0,10.2564,33.3333,0,0,0,0,2.5641,5.1282,2.5641,0,0'
    
    inp = sc.textFile("ebird.abund_species_all.random.merge.txt", minPartitions=numP*32) \
        .pipe("./0_stixel_keys_SparkHPCWrapper.R") \
        .map(lambda line: line.split("\n")) \
        .filter(lambda p:  p[0].split("\t")[0] != "") \
        .map(lambda p: (p[0].split("\t")[0]+","+p[0].split("\t")[1]))
    
    # output format
    # u'1-10-14-8,83278,"train",-76.5018492,42.4622906,0.8606557,2012,315,9.25,2,0.805,2,0,112,10,28.2051,17.9487,0,0,10.2564,33.3333,0,0,0,0,2.5641,5.1282,2.5641,0,0'
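    One spot that looks avoidable is splitting each record on "\t" three separate times across the filter and map steps. A minimal sketch of a single-split version (the `parse_piped` name and the return-None convention are mine, not part of the original pipeline):

    ```python
    def parse_piped(line):
        # Split the tab-delimited record from the R pipe once, instead of
        # calling split("\t") three times per record as in the chain above.
        fields = line.split("\t")
        if not fields[0]:          # same check as the original filter step
            return None
        return fields[0] + "," + fields[1]

    # In Spark, this could replace the filter + map pair, e.g.:
    # inp = sc.textFile(...).pipe(...).map(parse_piped).filter(lambda x: x is not None)
    ```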
    

    Next, create a DataFrame from a subset of the columns in my RDD.

    parts = inp.map(lambda l: l.split(","))
    rows = parts.map(lambda p: Row(stixel_id=p[0], data_type=p[2], obs=p[29]))
    rowsdf = s1.createDataFrame(rows)
    
    # output format
    # Row(data_type=u'"train"', obs=u'0', stixel_id=u'1-10-14-8')
    

    Next, cast the obs field to an integer, filter to the training rows, and sum obs by stixel_id and data_type.

    rowsdfc = rowsdf.select(rowsdf.stixel_id, rowsdf.data_type, rowsdf.obs.cast("integer"))
    sums = rowsdfc.filter(rowsdfc.data_type.like("%train%")) \
        .groupBy(["stixel_id", "data_type"]) \
        .sum("obs") \
        .collect()
    sumsdf = s1.createDataFrame(sums)
    
    # output format
    # Row(stixel_id=u'4-10-14-8', data_type=u'"train"', sum(obs)=5)
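    As a sanity check on what this step computes, here is a plain-Python sketch of the same aggregation (the `sum_obs_local` helper is hypothetical and only mirrors the Row fields above):

    ```python
    from collections import defaultdict

    def sum_obs_local(rows):
        # rows: iterable of (stixel_id, data_type, obs) tuples, mirroring the Row fields
        totals = defaultdict(int)
        for stixel_id, data_type, obs in rows:
            if "train" in data_type:          # same predicate as .like("%train%")
                totals[(stixel_id, data_type)] += int(obs)
        return dict(totals)
    ```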
    

    Second to last, filter to the stixel_ids where sum(obs) > 5 and build a plain list of them.

    sumsdfpos = sumsdf.filter(sumsdf["sum(obs)"] > 5)
    stixels = sumsdfpos.select(sumsdfpos.stixel_id)
    stixel_list = stixels.select("stixel_id").rdd.map(lambda x: x[0]).collect()
    fit_tasks = len(stixel_list)
    
    # output
    # [u'2-10-13-7', u'3-11-13-8', u'1-10-13-8', u'3-10-12-7', u'3-10-13-6', u'4-11-14-7', u'4-11-13-9', u'3-11-13-6', u'2-10-13-8', u'2-11-13-7', u'2-10-13-9', u'1-10-13-7', u'4-11-14-8', u'5-11-12-7', u'1-10-14-7', u'5-11-13-8', u'1-9-13-7', u'2-10-13-6', u'1-10-14-8', u'2-11-13-8', u'1-10-13-9', u'5-11-13-7', u'5-10-12-7', u'3-11-12-7', u'2-11-13-6', u'3-10-13-7', u'4-11-13-7', u'5-10-13-7', u'4-10-14-7', u'4-11-13-8', u'3-10-12-6', u'4-10-13-7', u'1-9-14-7', u'5-11-12-8', u'3-11-12-6', u'3-11-13-7']
    

    Finally, filter my RDD to the stixel_ids in this list.

    inpf = inp.filter(lambda p: p.split(",")[0] in stixel_list)
    
    # output format
    # u'1-10-14-8,83278,"train",-76.5018492,42.4622906,0.8606557,2012,315,9.25,2,0.805,2,0,112,10,28.2051,17.9487,0,0,10.2564,33.3333,0,0,0,0,2.5641,5.1282,2.5641,0,0'
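    The membership test here scans stixel_list for every record. One likely improvement is converting the list to a set once, so each lookup is O(1) instead of a list scan (the `make_stixel_filter` name is mine, for illustration):

    ```python
    def make_stixel_filter(stixel_list):
        # Build the set once; the returned closure is cheap to apply per record.
        stixel_set = set(stixel_list)
        # maxsplit=1 avoids splitting the whole record just to read the first field
        return lambda record: record.split(",", 1)[0] in stixel_set

    # In Spark this would replace the filter above, e.g.:
    # inpf = inp.filter(make_stixel_filter(stixel_list))
    ```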
    

0 Answers:

No answers yet.