PySpark-在dataframe列中创建的列表的类型为String而不是Integer

时间:2019-01-23 17:11:06

标签: python list pyspark

我有一个数据框-

values = [('A',8),('B',7)]
df = sqlContext.createDataFrame(values,['col1','col2'])
df.show()
+----+----+
|col1|col2|
+----+----+
|   A|   8|
|   B|   7|
+----+----+

我希望从0到list之间的{strong>偶数个col2

#Returns even numbers
def make_list(col):
    return list(map(int,[x for x in range(col+1) if x % 2 == 0]))
make_list = udf(make_list)

df = df.withColumn('list',make_list(col('col2')))
df.show()
+----+----+---------------+
|col1|col2|           list|
+----+----+---------------+
|   A|   8|[0, 2, 4, 6, 8]|
|   B|   7|   [0, 2, 4, 6]|
+----+----+---------------+
df.printSchema()
root
 |-- col1: string (nullable = true)
 |-- col2: long (nullable = true)
 |-- list: string (nullable = true)

我得到了想要的列表,但是列表的类型为string而不是int,如您在上面的printschema中所见。

如何获得list类型的int?如果没有int类型,则无法explode此数据框。

关于如何获得list的{​​{1}}的任何想法?

2 个答案:

答案 0 :(得分:2)

您需要指定udf的返回类型;要获得list中的int,请使用ArrayType(IntegerType())

from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, IntegerType

# specify the return type as ArrayType(IntegerType())
make_list_udf = udf(make_list, ArrayType(IntegerType()))

df = df.withColumn('list',make_list_udf(col('col2')))
df.show()
+----+----+------------+                                                        
|col1|col2|        list|
+----+----+------------+
|   A|   6|[0, 2, 4, 6]|
|   B|   7|[0, 2, 4, 6]|
+----+----+------------+

df.printSchema()
root
 |-- col1: string (nullable = true)
 |-- col2: long (nullable = true)
 |-- list: array (nullable = true)
 |    |-- element: integer (containsNull = true)

或者,如果您使用的是spark 2.4,则可以使用新的sequence函数:

values = [('A',8),('B',7)]
df = sqlContext.createDataFrame(values,['col1','col2'])

from pyspark.sql.functions import sequence, lit, col
df.withColumn('list', sequence(lit(0), col('col2'), step=lit(2))).show()
+----+----+---------------+
|col1|col2|           list|
+----+----+---------------+
|   A|   8|[0, 2, 4, 6, 8]|
|   B|   7|   [0, 2, 4, 6]|
+----+----+---------------+

答案 1 :(得分:2)

事实证明,有一个closed form function会获得将所需的list列中的数字连接起来所代表的数字。

我们可以实现此功能,然后仅使用API​​函数使用一些字符串操作和正则表达式来获得所需的输出。即使比较复杂,该应该仍比使用udf更快。

import pyspark.sql.functions as f

def getEvenNumList(x):
    n = f.floor(x/2)
    return f.split(
        f.concat(
            f.lit("0,"), 
            f.regexp_replace(
                (2./81.*(-9*n+f.pow(10, (n+1))-10)).cast('int').cast('string'), 
                r"(?<=\d)(?=\d)", 
                ","
            )
        ),
        ","
    ).cast("array<int>")

df = df.withColumn("list", getEvenNumList(f.col("col2")))
df.show()
#+----+----+---------------+
#|col1|col2|           list|
#+----+----+---------------+
#|   A|   8|[0, 2, 4, 6, 8]|
#|   B|   7|   [0, 2, 4, 6]|
#+----+----+---------------+

df.printSchema()
#root
# |-- col1: string (nullable = true)
# |-- col2: long (nullable = true)
# |-- list: array (nullable = true)
# |    |-- element: integer (containsNull = true)

说明

所需列表中的元素数为1加col2的底数除以2。(加1表示前导0)。暂时忽略0,将n设为col2的底数除以2。

如果您将列表中的数字连在一起(可以使用str.join),则结果数字将由表达式给出:

2*sum(i*10**(n-i) for i in range(1,n+1))

使用Wolfram Alpha,您可以为此和计算一个闭合形式的方程。

一旦有了该数字,就可以将其转换为以0开头的字符串。

最后,我在每个数字之间添加了逗号作为分隔符,对结果进行了拆分,并将其转换为整数数组。