Trying to create a new column in a PySpark UDF, but the values are null!
data_list = [['a', [1, 2, 3]], ['b', [4, 5, 6]], ['c', [2, 4, 6, 8]], ['d', [4, 1]], ['e', [1, 2]]]
all_cols = ['COL1','COL2']
df = sqlContext.createDataFrame(data_list, all_cols)
df.show()
+----+------------+
|COL1|        COL2|
+----+------------+
|   a|   [1, 2, 3]|
|   b|   [4, 5, 6]|
|   c|[2, 4, 6, 8]|
|   d|      [4, 1]|
|   e|      [1, 2]|
+----+------------+
df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)
def cr_pair(idx_src, idx_dest):
    idx_dest.append(idx_dest.pop(0))
    return idx_src, idx_dest
lst1 = [1,2,3]
lst2 = [1,2,3]
cr_pair(lst1, lst2)
([1, 2, 3], [2, 3, 1])
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
from pyspark.sql.types import ArrayType
get_idx_pairs = udf(lambda x: cr_pair(x, x), ArrayType(IntegerType()))
df = df.select('COL1', 'COL2', get_idx_pairs('COL2').alias('COL3'))
df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- COL3: array (nullable = true)
 |    |-- element: integer (containsNull = true)
df.show()
+----+------------+------------+
|COL1|        COL2|        COL3|
+----+------------+------------+
|   a|   [1, 2, 3]|[null, null]|
|   b|   [4, 5, 6]|[null, null]|
|   c|[2, 4, 6, 8]|[null, null]|
|   d|      [4, 1]|[null, null]|
|   e|      [1, 2]|[null, null]|
+----+------------+------------+
Where is the problem? I am getting all 'null' values in the COL3 column. The expected result should be:
+----+------------+----------------------------+
|COL1|COL2        |COL3                        |
+----+------------+----------------------------+
|a   |[1, 2, 3]   |[[1, 2, 3], [2, 3, 1]]      |
|b   |[4, 5, 6]   |[[4, 5, 6], [5, 6, 4]]      |
|c   |[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|d   |[4, 1]      |[[4, 1], [1, 4]]            |
|e   |[1, 2]      |[[1, 2], [2, 1]]            |
+----+------------+----------------------------+
Answer 0 (score: 1)
Your UDF should return ArrayType(ArrayType(IntegerType())), since you expect the column to hold a list of lists. (That is also why you got nulls: when the declared return type does not match the data a UDF actually returns, Spark silently substitutes null instead of raising an error.) Besides, the function only needs one parameter:
def cr_pair(idx_src):
    return idx_src, idx_src[1:] + idx_src[:1]

get_idx_pairs = udf(cr_pair, ArrayType(ArrayType(IntegerType())))
df.withColumn('COL3', get_idx_pairs(df['COL2'])).show(5, False)
+----+------------+----------------------------+
|COL1|COL2        |COL3                        |
+----+------------+----------------------------+
|a   |[1, 2, 3]   |[[1, 2, 3], [2, 3, 1]]      |
|b   |[4, 5, 6]   |[[4, 5, 6], [5, 6, 4]]      |
|c   |[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|d   |[4, 1]      |[[4, 1], [1, 4]]            |
|e   |[1, 2]      |[[1, 2], [2, 1]]            |
+----+------------+----------------------------+
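
To see where the [null, null] in the question came from: Spark treats the returned tuple as an array of the declared element type, and any element it cannot convert becomes null. A minimal sketch reproducing the symptom, reusing the df defined in the question:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

# Same shape as the question's UDF: the declared element type is IntegerType,
# but each element of the returned tuple is itself a list. Spark cannot
# convert a list to an integer, so every element becomes null -> [null, null].
bad_udf = udf(lambda x: (x, x), ArrayType(IntegerType()))
df.select(bad_udf('COL2').alias('COL3')).show()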
Answer 1 (score: 1)
It seems that what you want to do is cyclically shift the elements of the list. Here is a non-udf approach using pyspark.sql.functions.posexplode() (Spark 2.1 and later); a built-in, UDF-free alternative for Spark 2.4+ is sketched at the end of this answer:
import pyspark.sql.functions as f
from pyspark.sql import Window
w = Window.partitionBy("COL1", "COL2").orderBy(f.col("pos") == 0, "pos")
df = df.select("*", f.posexplode("COL2"))\
    .select("COL1", "COL2", "pos", f.collect_list("col").over(w).alias('COL3'))\
    .where("pos = 0")\
    .drop("pos")\
    .withColumn("COL3", f.array("COL2", "COL3"))
df.show(truncate=False)
#+----+------------+----------------------------------------------------+
#|COL1|COL2        |COL3                                                |
#+----+------------+----------------------------------------------------+
#|a   |[1, 2, 3]   |[WrappedArray(1, 2, 3), WrappedArray(2, 3, 1)]      |
#|b   |[4, 5, 6]   |[WrappedArray(4, 5, 6), WrappedArray(5, 6, 4)]      |
#|c   |[2, 4, 6, 8]|[WrappedArray(2, 4, 6, 8), WrappedArray(4, 6, 8, 2)]|
#|d   |[4, 1]      |[WrappedArray(4, 1), WrappedArray(1, 4)]            |
#|e   |[1, 2]      |[WrappedArray(1, 2), WrappedArray(2, 1)]            |
#+----+------------+----------------------------------------------------+
Using posexplode will return two columns: the position in the list (pos) and the value (col). The trick here is that we order by f.col("pos") == 0 first and then by "pos", which moves the first position in the array to the end of the list. A small plain-Python illustration of this sort key follows below.
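
The boolean key works because False sorts before True, so the pos == 0 row lands last while the rest keep their order. For the row [1, 2, 3], posexplode emits the (pos, col) pairs shown here:

# (pos, col) pairs as emitted by posexplode for the row [1, 2, 3]
pairs = [(0, 1), (1, 2), (2, 3)]
print(sorted(pairs, key=lambda p: (p[0] == 0, p[0])))
# [(1, 2), (2, 3), (0, 1)]  -> collect_list over col yields [2, 3, 1]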
Although this output prints differently from how you may expect a list of lists to look in Python, the contents of COL3 really are a list of lists of integers.
df.printSchema()
#root
# |-- COL1: string (nullable = true)
# |-- COL2: array (nullable = true)
# |    |-- element: long (containsNull = true)
# |-- COL3: array (nullable = false)
# |    |-- element: array (containsNull = true)
# |    |    |-- element: long (containsNull = true)
Update:
The WrappedArray prefix is just the way Spark prints nested lists; the underlying array is exactly what you need. One way to verify this is to call collect() and inspect the data:
results = df.collect()
print([(r["COL1"], r["COL3"]) for r in results])
#[(u'a', [[1, 2, 3], [2, 3, 1]]),
# (u'b', [[4, 5, 6], [5, 6, 4]]),
# (u'c', [[2, 4, 6, 8], [4, 6, 8, 2]]),
# (u'd', [[4, 1], [1, 4]]),
# (u'e', [[1, 2], [2, 1]])]
Or if you convert df to a pandas DataFrame:
print(df.toPandas())
#  COL1          COL2                          COL3
#0    a     [1, 2, 3]        ([1, 2, 3], [2, 3, 1])
#1    b     [4, 5, 6]        ([4, 5, 6], [5, 6, 4])
#2    c  [2, 4, 6, 8]  ([2, 4, 6, 8], [4, 6, 8, 2])
#3    d        [4, 1]              ([4, 1], [1, 4])
#4    e        [1, 2]              ([1, 2], [2, 1])
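
As a final aside beyond what either answer shows: on Spark 2.4 and later, the same cyclic shift can be written with the built-in SQL array functions slice, size and concat, avoiding both the UDF and the explode/window machinery. A minimal sketch, assuming Spark 2.4+ and the original df from the question:

import pyspark.sql.functions as f

# slice(COL2, 2, size(COL2) - 1) drops the first element (SQL arrays are
# 1-indexed), and concat appends slice(COL2, 1, 1), i.e. the first element,
# at the end, yielding the left-rotated copy of COL2.
rotated = f.expr("concat(slice(COL2, 2, size(COL2) - 1), slice(COL2, 1, 1))")
df.withColumn("COL3", f.array(f.col("COL2"), rotated)).show(truncate=False)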