Exploding an array - (DataFrame) PySpark

Date: 2016-10-18 12:52:01

Tags: python apache-spark pyspark spark-dataframe

I have a DataFrame like this:

+-----+--------------------+
|index|              merged|
+-----+--------------------+
|    0|[[2.5, 2.4], [3.5...|
|    1|[[-1.0, -1.0], [-...|
|    2|[[-1.0, -1.0], [-...|
|    3|[[0.0, 0.0], [0.5...|
|    4|[[0.5, 0.5], [1.0...|
|    5|[[0.5, 0.5], [1.0...|
|    6|[[-1.0, -1.0], [0...|
|    7|[[0.0, 0.0], [0.5...|
|    8|[[0.5, 0.5], [1.0...|
+-----+--------------------+

I want to explode the merged column into:

+-----+-------+-------+
|index|Column1|Column2|
+-----+-------+-------+
|    0|    2.5|    2.4|
|    1|    3.5|    0.5|
|    2|   -1.0|   -1.0|
|    3|   -1.0|   -1.0|
|    4|    0.0|    0.0|
|    5|    0.5|   0.74|
+-----+-------+-------+

Each pair [[2.5, 2.4], [3.5, 0.5]] should fill the two columns: 2.5 and 3.5 go into Column1, and 2.4 and 0.5 go into Column2.

So I tried this:

df = df.withColumn("merged", df["merged"].cast("array<array<float>>"))
df = df.withColumn("merged", explode('merged'))

Then I would apply a UDF to create another DataFrame.

But I can neither cast the data nor apply explode; I get this error:

pyspark.sql.utils.AnalysisException: u"cannot resolve 'cast(merged as array<array<float>>)' due to data type mismatch: cannot cast StringType to ArrayType(StringType,true);"

I also tried:

df = df.withColumn("merged", df["merged"].cast("array<string>"))

But nothing worked. And if I apply explode without the cast, I get:

pyspark.sql.utils.AnalysisException: u"cannot resolve 'explode(merged)' due to data type mismatch: input to function explode should be array or map type, not StringType;"
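
Both errors point to the same root cause: merged was loaded as a plain StringType column, and Spark will not cast a string to a nested array type, so the string has to be parsed instead. A minimal sketch of such a parse, assuming the strings are Python-style literals like "[[2.5, 2.4], [3.5, 0.5]]" (the parse_merged helper is illustrative, not from the original post):

import ast

from pyspark.sql.functions import explode, udf
from pyspark.sql.types import ArrayType, FloatType

# Illustrative helper: turn the string into a real nested array of floats.
def parse_merged(s):
    return [[float(x) for x in pair] for pair in ast.literal_eval(s)]

parse_udf = udf(parse_merged, ArrayType(ArrayType(FloatType())))

df = df.withColumn("merged", parse_udf(df["merged"]))
# Now merged is array<array<float>>, so explode works:
df = df.withColumn("merged", explode(df["merged"]))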

1 answer:

Answer 0 (score: 0):

You can try the following code:

from pyspark.sql import SparkSession
from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf


def col1_calc(merged):
    # First element of the first inner pair, e.g. 2.5 in [[2.5, 2.4], ...]
    return merged[0][0]

def col2_calc(merged):
    # Second element of the first inner pair, e.g. 2.4 in [[2.5, 2.4], ...]
    return merged[0][1]

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Spark SQL Hive integration example") \
        .getOrCreate()

    # Sample data: here merged is a real array<array<double>> column,
    # so the UDFs can index into it directly.
    df = spark.createDataFrame([
        (0, [[2.5, 2.4], [3.5]]),
        (1, [[-1.0, -1.0], [3.5]]),
        (2, [[-1.0, -1.0], [3.5]]),
    ], ["index", "merged"])

    df.show()

    # Wrap the plain Python functions as UDFs returning floats.
    column1_calc = udf(col1_calc, FloatType())
    df = df.withColumn('Column1', column1_calc(df['merged']))
    column2_calc = udf(col2_calc, FloatType())
    df = df.withColumn('Column2', column2_calc(df['merged']))

    df = df.select(['Column1', 'Column2', 'index'])
    df.show()

Output:

+-------+-------+-----+
|Column1|Column2|index|
+-------+-------+-----+
|    2.5|    2.4|    0|
|   -1.0|   -1.0|    1|
|   -1.0|   -1.0|    2|
+-------+-------+-----+