How do I flatten a df?

Asked: 2019-05-26 16:13:23

Tags: dataframe pyspark

I have a DataFrame in Spark with the schema shown below, and I want to flatten it using the flatten_df function further down, but the output stays the same. Why?

My DataFrame schema looks like this:

 Selected_duration_df.printSchema()

Output:

  root
   |-- Duration: long (nullable = true)
   |-- event_end_date: timestamp (nullable = true)
   |-- event_start_date: timestamp (nullable = true)
   |-- location_id: long (nullable = true)
   |-- location_name: string (nullable = true)
   |-- product_id: string (nullable = true)
   |-- sensor_id: long (nullable = true)
   |-- sensor_name: string (nullable = true)
   |-- fault_fault_code: string (nullable = true)
   |-- fault_fault_description: string (nullable = true)
   |-- product_model_features: array (nullable = true)
   |    |-- element: struct (containsNull = true)
   |    |    |-- key: string (nullable = true)
   |    |    |-- value: string (nullable = true)

I tried the following flatten_df function:

 from pyspark.sql.functions import col

 def flatten_df(nested_df, layers):
     flat_cols = []
     nested_cols = []
     flat_df = []

     # Split the columns into plain columns and struct-typed columns
     flat_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] != 'struct'])
     nested_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] == 'struct'])

     # Keep the plain columns and pull every field out of each struct,
     # renaming it to <struct>_<field>
     flat_df.append(nested_df.select(flat_cols[0] +
                           [col(nc + '.' + c).alias(nc + '_' + c)
                            for nc in nested_cols[0]
                            for c in nested_df.select(nc + '.*').columns])
              )
     for i in range(1, layers):
         print(flat_cols[i - 1])
         flat_cols.append([c[0] for c in flat_df[i - 1].dtypes if c[1][:6] != 'struct'])
         nested_cols.append([c[0] for c in flat_df[i - 1].dtypes if c[1][:6] == 'struct'])

         flat_df.append(flat_df[i - 1].select(flat_cols[i] +
                            [col(nc + '.' + c).alias(nc + '_' + c)
                                for nc in nested_cols[i]
                                for c in flat_df[i - 1].select(nc + '.*').columns])
                  )

     return flat_df[-1]


  my_flattened_df = flatten_df(Selected_duration_df, 3)

The output is unchanged. my_flattened_df.printSchema() prints:

 root
  |-- Duration: long (nullable = true)
  |-- event_end_date: timestamp (nullable = true)
  |-- event_start_date: timestamp (nullable = true)
  |-- location_id: long (nullable = true)
  |-- location_name: string (nullable = true)
  |-- product_id: string (nullable = true)
  |-- sensor_id: long (nullable = true)
  |-- sensor_name: string (nullable = true)
  |-- fault_fault_code: string (nullable = true)
  |-- fault_fault_description: string (nullable = true)
  |-- product_model_features: array (nullable = true)
  |    |-- element: struct (containsNull = true)
  |    |    |-- key: string (nullable = true)
  |    |    |-- value: string (nullable = true)

1 answer:

Answer 0 (score: 0)

Your flatten_df only flattens struct columns: the filter c[1][:6] == 'struct' never matches product_model_features, because its dtype string starts with 'array' (it is an array of structs), so the function returns the DataFrame unchanged. For an array of structs you can simplify this with Spark's explode function.
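A quick check makes this visible (a sketch, assuming the schema from the question):

print(Selected_duration_df.dtypes)
# [..., ('product_model_features', 'array<struct<key:string,value:string>>')]
# The dtype string starts with 'array', not 'struct', so the
# c[1][:6] == 'struct' filter never selects this column.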

See the example below:

import pandas
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode

spark = SparkSession \
    .builder \
    .appName('stackoverflow') \
    .getOrCreate()

data = {
    'location_id': [1, 2, 3],
    'product_model_features': [
        [{'key': 'A', 'value': 'B'}, {'key': 'C', 'value': 'D'}, {'key': 'E', 'value': 'F'}],
        [{'key': 'G', 'value': 'H'}, {'key': 'I', 'value': 'J'}, {'key': 'K', 'value': 'L'}],
        [{'key': 'M', 'value': 'N'}, {'key': 'O', 'value': 'P'}, {'key': 'Q', 'value': 'R'}]
    ]
}
df = pandas.DataFrame(data)
df = spark.createDataFrame(df)
df = df.withColumn('p', explode('product_model_features')) \
    .select('location_id', 'p.key', 'p.value')
df.show()

Output:

+-----------+---+-----+
|location_id|key|value|
+-----------+---+-----+
|          1|  A|    B|
|          1|  C|    D|
|          1|  E|    F|
|          2|  G|    H|
|          2|  I|    J|
|          2|  K|    L|
|          3|  M|    N|
|          3|  O|    P|
|          3|  Q|    R|
+-----------+---+-----+
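One caveat (an addition, not part of the original answer): explode drops rows whose array is null or empty. If those rows must be kept, explode_outer performs the same flattening but keeps them with null key and value. A sketch, applied to the DataFrame as it was before the explode step above:

from pyspark.sql.functions import explode_outer

# Same flattening as above, except rows with a null or empty
# product_model_features array survive with key and value set to null.
df = df.withColumn('p', explode_outer('product_model_features')) \
    .select('location_id', 'p.key', 'p.value')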