根据指定位置拆分列

时间:2018-05-17 05:25:21

标签: apache-spark pyspark apache-spark-sql

假设我的数据帧如下,我想以有效的方式基于位置拆分Col1

df = sc.parallelize([['sdbsajkdbnasjdh'],['sdahasdbasjda']]).toDF(['Col1'])

+---------------+
|           Col1|
+---------------+
|sdbsajkdbnasjdh|
|  sdahasdbasjda|
+---------------+

pos = [(1,2),(3,5),(7,10)]

例如,基于pos列表我想要如下结果集:

d|sa|dbn
d|ha|bas

我想要有效的方法来分割数据。

我目前的解决方案如下,但如果我使用EofError提供更长的pos列表(如果列表中有10个元组),则失败。

udf1 = udf(lambda x:  "|".join(str(x) for x in [x[j[0]:j[1]] for j in pos]),StringType())
final_df = df.withColumn("Split",udf1('Col1'))

3 个答案:

答案 0 :(得分:1)

我们可以使用pyspark函数中的substr()来为每个子字符串获取单独的列。然后使用UDF进行rowwise组合以加入列.Have尝试了下面的代码,

from pyspark.sql import functions as F

udf1 = F.udf(lambda x : '|'.join(x))

df = df.withColumn('Col2',udf1(F.struct([df.Col1.substr(j[0]+1,j[1]-j[0]) for j in pos])))
+---------------+--------+
|           Col1|    Col2|
+---------------+--------+
|sdbsajkdbnasjdh|d|sa|dbn|
|  sdahasdbasjda|d|ha|bas|
+---------------+--------+

当substr()取startpos和length时,我们计算长度。 希望这有帮助。!

答案 1 :(得分:0)

也许仅仅翻过一次字符串就可以提高效率。

(假设“pos”的当前和正确结构):

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

df = sc.parallelize([['sdbsajkdbngdfgdfgdfgdfgdgdfgdgdfgdgdgsfhhfghfghdfghdfasjdh'],\
                     ['sdahasdbsdfgsdfhhsfghhkghmaeserewyuouip,mfhdfbdfgbdfgjdtsagasjda']]).toDF(['Col1'])
pos = [(1,2),(3,5),(7,10),(11,15),(20,21),(22,24),(24,26),(26,30),(35,38),(38,42),(42,43),(45,50)]

flat_pos = [x for y in pos for x in y]

def split_function(row):
    return_list = []
    start_copy = False
    flat_pos_index = 0
    temp_list = []
    max_pos = flat_pos[-1]
    i = 0
    while i < len(row):
        if i > max_pos:
            break
        elif i == flat_pos[flat_pos_index]: 
            if flat_pos_index % 2 == 0:
                start_copy = True
            else:
                start_copy = False
                return_list.append("".join(temp_list))
                temp_list = []
            flat_pos_index += 1
        if start_copy:
            temp_list.append(row[i])
        if flat_pos_index + 1 < len(flat_pos) and flat_pos[flat_pos_index] == flat_pos[flat_pos_index + 1]:
                flat_pos_index += 1
                start_copy = False
                return_list.append("".join(temp_list))
                temp_list = []
        else:
            i += 1

    return "|".join(return_list)


udf2 = udf(split_function ,StringType())
final_df = df.withColumn("Split",udf2('Col1'))

final_df.collect()
  

[行(Col1中= u'sdbsajkdbngdfgdfgdfgdfgdgdfgdgdfgdgdgsfhhfghfghdfghdfasjdh”,   分裂= u'd | SA | DBN | DFGD | d | G | G | fgdg | d | F | G | ghdfg'),   行(Col1中= u'sdahasdbsdfgsdfhhsfghhkghmaeserewyuouip,mfhdfbdfgbdfgjdtsagasjda”,   分裂= u'd |公顷| BSD | GSDF | H | K | H | AESE | O | P | H | bdfgb')]

答案 2 :(得分:0)

您不需要使用udf

相反,您可以结合使用pyspark.sql.functions.arraypyspark.sql.functions.substring对元组使用列表推导来获得所需的子字符串。

请注意,substring()的第一个参数将字符串的开头视为索引1,因此我们传入start+1。第二个参数是字符串长度,所以我传递(stop-start)

import pyspark.sql.functions as f
df.withColumn(
    'Split',
    f.array(
        [
            f.substring(
                str=f.col('Col1'),
                pos=start+1,
                len=(stop-start)
            ) 
            for start, stop in pos
        ]
    )
).show()
#+---------------+------------+
#|           Col1|       Split|
#+---------------+------------+
#|sdbsajkdbnasjdh|[d, sa, dbn]|
#|  sdahasdbasjda|[d, ha, bas]|
#+---------------+------------+

要使用"|"将这些内容连接在一起,请使用pyspark.sql.functions.concat_ws将调用包裹到array()

df.withColumn(
    'Split',
    f.concat_ws(
        "|",
        f.array(
            [
                f.substring(f.col('Col1'), start+1, (stop-start)) 
                for start, stop in pos
            ]
        )
    )
).show()
#+---------------+--------+
#|           Col1|   Split|
#+---------------+--------+
#|sdbsajkdbnasjdh|d|sa|dbn|
#|  sdahasdbasjda|d|ha|bas|
#+---------------+--------+

使用DataFrame函数比使用udf更快。