Using a UDF to create a PySpark DataFrame column that mimics numpy's np.roll function

Date: 2018-07-23 19:39:55

Tags: apache-spark pyspark apache-spark-sql user-defined-functions

I am trying to create a new column with a PySpark UDF, but the values come back null!

Create the DataFrame

data_list = [['a', [1, 2, 3]], ['b', [4, 5, 6]],['c', [2, 4, 6, 8]],['d', [4, 1]],['e', [1,2]]]
all_cols = ['COL1','COL2']
df = sqlContext.createDataFrame(data_list, all_cols)
df.show()
+----+------------+
|COL1|        COL2|
+----+------------+
|   a|   [1, 2, 3]|
|   b|   [4, 5, 6]|
|   c|[2, 4, 6, 8]|
|   d|      [4, 1]|
|   e|      [1, 2]|
+----+------------+

df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)

Create the function

def cr_pair(idx_src, idx_dest):
    idx_dest.append(idx_dest.pop(0))
    return idx_src, idx_dest
lst1 = [1,2,3]
lst2 = [1,2,3]
cr_pair(lst1, lst2)
([1, 2, 3], [2, 3, 1])

Create and register the UDF

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

from pyspark.sql.types import ArrayType
get_idx_pairs = udf(lambda x: cr_pair(x, x), ArrayType(IntegerType()))

Add the new column to the DataFrame

df = df.select('COL1', 'COL2',  get_idx_pairs('COL2').alias('COL3'))
df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- COL3: array (nullable = true)
 |    |-- element: integer (containsNull = true)

df.show()
+----+------------+------------+
|COL1|        COL2|        COL3|
+----+------------+------------+
|   a|   [1, 2, 3]|[null, null]|
|   b|   [4, 5, 6]|[null, null]|
|   c|[2, 4, 6, 8]|[null, null]|
|   d|      [4, 1]|[null, null]|
|   e|      [1, 2]|[null, null]|
+----+------------+------------+

Where is the problem? I am getting all 'null' values in the COL3 column. The expected result should be:

+----+------------+----------------------------+
|COL1|        COL2|                        COL3|
+----+------------+----------------------------+
|   a|   [1, 2, 3]|[[1, 2, 3], [2, 3, 1]]      |
|   b|   [4, 5, 6]|[[4, 5, 6], [5, 6, 4]]      |
|   c|[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|   d|      [4, 1]|[[4, 1], [1, 4]]            |
|   e|      [1, 2]|[[1, 2], [2, 1]]            |
+----+------------+----------------------------+

2 Answers:

Answer 0: (score: 1)

Your UDF should return ArrayType(ArrayType(IntegerType())), since you are expecting a list of lists in the column; besides, it only needs one parameter:

def cr_pair(idx_src):
    return idx_src, idx_src[1:] + idx_src[:1]

get_idx_pairs = udf(cr_pair, ArrayType(ArrayType(IntegerType())))
df.withColumn('COL3', get_idx_pairs(df['COL2'])).show(5, False)
+----+------------+----------------------------+
|COL1|COL2        |COL3                        |
+----+------------+----------------------------+
|a   |[1, 2, 3]   |[[1, 2, 3], [2, 3, 1]]      |
|b   |[4, 5, 6]   |[[4, 5, 6], [5, 6, 4]]      |
|c   |[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|d   |[4, 1]      |[[4, 1], [1, 4]]            |
|e   |[1, 2]      |[[1, 2], [2, 1]]            |
+----+------------+----------------------------+
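
Since the goal is to mimic np.roll, the same idea extends to an arbitrary shift. Below is a minimal sketch under that assumption; the roll_n helper and roll_udf name are made up here for illustration and are not part of the original answer:

from pyspark.sql import functions as f
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, LongType

# Hypothetical helper: roll a Python list by n positions, like np.roll(xs, n)
# (positive n shifts elements to the right, negative n to the left).
def roll_n(xs, n):
    if not xs:
        return xs
    k = -n % len(xs)
    return xs[k:] + xs[:k]

roll_udf = udf(roll_n, ArrayType(LongType()))

# Left-rotate by one, as in the question; the literal shift must be wrapped in f.lit().
df.withColumn('COL3', roll_udf('COL2', f.lit(-1))).show(5, False)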

Answer 1: (score: 1)

It seems what you want to do is cyclically shift the elements of the list. Here is a non-udf approach using pyspark.sql.functions.posexplode() (Spark version 2.1 and above):

import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.partitionBy("COL1", "COL2").orderBy(f.col("pos") == 0, "pos")
df = df.select("*", f.posexplode("COL2"))\
    .select("COL1", "COL2", "pos", f.collect_list("col").over(w).alias('COL3'))\
    .where("pos = 0")\
    .drop("pos")\
    .withColumn("COL3", f.array("COL2", "COL3"))

df.show(truncate=False)
#+----+------------+----------------------------------------------------+
#|COL1|COL2        |COL3                                                |
#+----+------------+----------------------------------------------------+
#|a   |[1, 2, 3]   |[WrappedArray(1, 2, 3), WrappedArray(2, 3, 1)]      |
#|b   |[4, 5, 6]   |[WrappedArray(4, 5, 6), WrappedArray(5, 6, 4)]      |
#|c   |[2, 4, 6, 8]|[WrappedArray(2, 4, 6, 8), WrappedArray(4, 6, 8, 2)]|
#|d   |[4, 1]      |[WrappedArray(4, 1), WrappedArray(1, 4)]            |
#|e   |[1, 2]      |[WrappedArray(1, 2), WrappedArray(2, 1)]            |
#+----+------------+----------------------------------------------------+

Using posexplode returns two columns: the position in the list (pos) and the value (col). The trick here is to order by f.col("pos") == 0 first and then by "pos". This moves the first position in the array to the end of the list.
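
For intuition, the same ordering key can be reproduced in plain Python; this toy snippet is only an illustration of the sort order, not part of the Spark job:

# Sorting positions by the tuple (pos == 0, pos) pushes position 0 to the end,
# which is what the Window's orderBy above does within each array.
values = [1, 2, 3]
rotated = [v for _, v in sorted(enumerate(values), key=lambda t: (t[0] == 0, t[0]))]
print(rotated)  # [2, 3, 1]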

Although this output prints differently than you would expect a list of lists in Python to look, the contents of COL3 are indeed a list of lists of integers.

df.printSchema()
#root
# |-- COL1: string (nullable = true)
# |-- COL2: array (nullable = true)
# |    |-- element: long (containsNull = true)
# |-- COL3: array (nullable = false)
# |    |-- element: array (containsNull = true)
# |    |    |-- element: long (containsNull = true)

Update

The "WrappedArray" prefix is just the way Spark prints nested lists. The underlying array is exactly what you need. One way to verify this is to call collect() and inspect the data:

results = df.collect()
print([(r["COL1"], r["COL3"]) for r in results])
#[(u'a', [[1, 2, 3], [2, 3, 1]]),
# (u'b', [[4, 5, 6], [5, 6, 4]]),
# (u'c', [[2, 4, 6, 8], [4, 6, 8, 2]]),
# (u'd', [[4, 1], [1, 4]]),
# (u'e', [[1, 2], [2, 1]])]

Or if you convert df to a pandas DataFrame:

print(df.toPandas())
#  COL1          COL2                          COL3
#0    a     [1, 2, 3]        ([1, 2, 3], [2, 3, 1])
#1    b     [4, 5, 6]        ([4, 5, 6], [5, 6, 4])
#2    c  [2, 4, 6, 8]  ([2, 4, 6, 8], [4, 6, 8, 2])
#3    d        [4, 1]              ([4, 1], [1, 4])
#4    e        [1, 2]              ([1, 2], [2, 1])