我有一个数据框-
values = [('A',8),('B',7)]
df = sqlContext.createDataFrame(values,['col1','col2'])
df.show()
+----+----+
|col1|col2|
+----+----+
| A| 8|
| B| 7|
+----+----+
我希望从0到list
之间的{strong>偶数个col2
。
#Returns even numbers
def make_list(col):
return list(map(int,[x for x in range(col+1) if x % 2 == 0]))
make_list = udf(make_list)
df = df.withColumn('list',make_list(col('col2')))
df.show()
+----+----+---------------+
|col1|col2| list|
+----+----+---------------+
| A| 8|[0, 2, 4, 6, 8]|
| B| 7| [0, 2, 4, 6]|
+----+----+---------------+
df.printSchema()
root
|-- col1: string (nullable = true)
|-- col2: long (nullable = true)
|-- list: string (nullable = true)
我得到了想要的列表,但是列表的类型为string
而不是int
,如您在上面的printschema
中所见。
如何获得list
类型的int
?如果没有int
类型,则无法explode
此数据框。
关于如何获得list
的{{1}}的任何想法?
答案 0 :(得分:2)
您需要指定udf
的返回类型;要获得list
中的int
,请使用ArrayType(IntegerType())
:
from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, IntegerType
# specify the return type as ArrayType(IntegerType())
make_list_udf = udf(make_list, ArrayType(IntegerType()))
df = df.withColumn('list',make_list_udf(col('col2')))
df.show()
+----+----+------------+
|col1|col2| list|
+----+----+------------+
| A| 6|[0, 2, 4, 6]|
| B| 7|[0, 2, 4, 6]|
+----+----+------------+
df.printSchema()
root
|-- col1: string (nullable = true)
|-- col2: long (nullable = true)
|-- list: array (nullable = true)
| |-- element: integer (containsNull = true)
或者,如果您使用的是spark 2.4,则可以使用新的sequence
函数:
values = [('A',8),('B',7)]
df = sqlContext.createDataFrame(values,['col1','col2'])
from pyspark.sql.functions import sequence, lit, col
df.withColumn('list', sequence(lit(0), col('col2'), step=lit(2))).show()
+----+----+---------------+
|col1|col2| list|
+----+----+---------------+
| A| 8|[0, 2, 4, 6, 8]|
| B| 7| [0, 2, 4, 6]|
+----+----+---------------+
答案 1 :(得分:2)
事实证明,有一个closed form function会获得将所需的list
列中的数字连接起来所代表的数字。
我们可以实现此功能,然后仅使用API函数使用一些字符串操作和正则表达式来获得所需的输出。即使比较复杂,该应该仍比使用udf
更快。
import pyspark.sql.functions as f
def getEvenNumList(x):
n = f.floor(x/2)
return f.split(
f.concat(
f.lit("0,"),
f.regexp_replace(
(2./81.*(-9*n+f.pow(10, (n+1))-10)).cast('int').cast('string'),
r"(?<=\d)(?=\d)",
","
)
),
","
).cast("array<int>")
df = df.withColumn("list", getEvenNumList(f.col("col2")))
df.show()
#+----+----+---------------+
#|col1|col2| list|
#+----+----+---------------+
#| A| 8|[0, 2, 4, 6, 8]|
#| B| 7| [0, 2, 4, 6]|
#+----+----+---------------+
df.printSchema()
#root
# |-- col1: string (nullable = true)
# |-- col2: long (nullable = true)
# |-- list: array (nullable = true)
# | |-- element: integer (containsNull = true)
说明
所需列表中的元素数为1加col2
的底数除以2。(加1表示前导0
)。暂时忽略0
,将n
设为col2
的底数除以2。
如果您将列表中的数字连在一起(可以使用str.join
),则结果数字将由表达式给出:
2*sum(i*10**(n-i) for i in range(1,n+1))
使用Wolfram Alpha,您可以为此和计算一个闭合形式的方程。
一旦有了该数字,就可以将其转换为以0开头的字符串。
最后,我在每个数字之间添加了逗号作为分隔符,对结果进行了拆分,并将其转换为整数数组。