Spark schema differences between partitions

Time: 2021-05-13 08:50:41

Tags: apache-spark pyspark

I have to read data from a path that is partitioned by region.

The US region has columns a, b, c, d, e; the EUR region only has a, b, c, d.

When I read the data from the path and call printSchema, I only see a, b, c, d; 'e' is missing.

Is there a way to handle this, e.g. have column e automatically filled with null for the EUR data?

3 Answers:

Answer 0: (score: 0)

After reading the data from the path, you can check whether the DataFrame contains the column 'e'. If it doesn't, you can add it with a default value, in this case None.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import lit

    spark = SparkSession.builder \
                        .appName('example') \
                        .getOrCreate()

    # 'data' and 'columns' stand in for whatever was read from the partitioned path
    df = spark.createDataFrame(data=data, schema=columns)

    # add the missing column, filled with nulls
    if 'e' not in df.columns:
        df = df.withColumn('e', lit(None))
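
As a follow-up, here is a minimal sketch of the same check applied right after reading from the partitioned path. The Parquet format, the path /data/events, and the StringType cast are assumptions for illustration, not part of the original answer; casting the null literal gives the new column a concrete type instead of NullType, so it lines up with the partitions that do have 'e'.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import lit
    from pyspark.sql.types import StringType

    spark = SparkSession.builder.appName('example').getOrCreate()

    # hypothetical partitioned location; swap in your real path and format
    df = spark.read.parquet("/data/events")

    # add the missing column as typed nulls so every region shares one schema
    if 'e' not in df.columns:
        df = df.withColumn('e', lit(None).cast(StringType()))

    df.printSchema()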

Answer 1: (score: 0)

You can collect all the columns that appear in either dataset and, for each dataset, fill in any column it is missing with None:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName('example').getOrCreate()

df_ab = (spark
    .sparkContext
    .parallelize([
        ('a1', 'b1'),
        ('a2', 'b2'),
    ])
    .toDF(['a', 'b'])
)

df_ab.show()
# +---+---+
# |  a|  b|
# +---+---+
# | a1| b1|
# | a2| b2|
# +---+---+

df_abcd = (spark
    .sparkContext
    .parallelize([
        ('a3', 'b3', 'c3', 'd3'),
        ('a4', 'b4', 'c4', 'd4'),
    ])
    .toDF(['a', 'b', 'c', 'd'])
)
df_abcd.show()
# +---+---+---+---+
# |  a|  b|  c|  d|
# +---+---+---+---+
# | a3| b3| c3| d3|
# | a4| b4| c4| d4|
# +---+---+---+---+

unique_columns = list(set(df_ab.columns + df_abcd.columns))
# ['d', 'b', 'a', 'c']

for col in unique_columns:
    if col not in df_ab.columns:
        df_ab = df_ab.withColumn(col, F.lit(None))
    if col not in df_abcd.columns:
        df_abcd = df_abcd.withColumn(col, F.lit(None))

df_ab.printSchema()
# root
#  |-- a: string (nullable = true)
#  |-- b: string (nullable = true)
#  |-- d: null (nullable = true)
#  |-- c: null (nullable = true)

df_ab.show()
# +---+---+----+----+
# |  a|  b|   d|   c|
# +---+---+----+----+
# | a1| b1|null|null|
# | a2| b2|null|null|
# +---+---+----+----+
    
df_abcd.printSchema()
# root
#  |-- a: string (nullable = true)
#  |-- b: string (nullable = true)
#  |-- c: string (nullable = true)
#  |-- d: string (nullable = true)

df_abcd.show()
# +---+---+---+---+
# |  a|  b|  c|  d|
# +---+---+---+---+
# | a3| b3| c3| d3|
# | a4| b4| c4| d4|
# +---+---+---+---+
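
As an alternative to the loop above, on Spark 3.1 or newer the same null-filling can be done in one step with unionByName and its allowMissingColumns flag. This is a sketch of that alternative, assuming the df_ab and df_abcd frames as originally created; it is not part of the original answer.

    # Spark 3.1+: columns missing on either side are added as nulls automatically
    df_all = df_ab.unionByName(df_abcd, allowMissingColumns=True)

    df_all.printSchema()
    df_all.show()
    # rows coming from df_ab get null for c and d; rows from df_abcd keep their values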

Answer 2: (score: 0)

I used pyspark with SQLContext. I hope this implementation helps you get the idea. Spark provides an environment for working with SQL, and Spark SQL is very convenient for handling this kind of thing.

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType


    class getData(object):
        """Builds example US and EU region DataFrames and unions them with Spark SQL."""

        def get_data(self, n):
            spark = SparkSession.builder.appName('YourProjectName').getOrCreate()

            # US region has four columns: a, b, c, d
            data2 = [("region 1", "region 2", "region 3", "region 4"),
                     ("region 5", "region 6", "region 7", "region 8")]

            schema = StructType([
                StructField("a", StringType(), True),
                StructField("b", StringType(), True),
                StructField("c", StringType(), True),
                StructField("d", StringType(), True)
            ])

            # EU region only has three columns: a, b, c
            data3 = [("EU region 1", "EU region 2", "EU region 3"),
                     ("EU region 5", "EU region 6", "EU region 7")]

            schema3 = StructType([
                StructField("a", StringType(), True),
                StructField("b", StringType(), True),
                StructField("c", StringType(), True)
            ])

            df = spark.createDataFrame(data=data2, schema=schema)
            df.createOrReplaceTempView("USRegion")
            sqlDF = spark.sql("SELECT * FROM USRegion")
            sqlDF.show(n)

            df1 = spark.createDataFrame(data=data3, schema=schema3)
            df1.createOrReplaceTempView("EURegion")
            sqlDF1 = spark.sql("SELECT * FROM EURegion")
            sqlDF1.show(n)

            # union both regions, filling the missing 'd' column on the EU side
            sql_union_df = spark.sql(
                "SELECT a, b, c, d FROM USRegion "
                "UNION ALL "
                "SELECT a, b, c, '' AS d FROM EURegion")
            sql_union_df.show(n)
            return sql_union_df


    #call the class
    conn = getData()
    #call the method implemented inside the class
    conn.get_data(10)
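
One small tweak, not in the original answer: if the missing column should be genuinely null rather than an empty string (which is what the question asks for), you can select a typed NULL on the EU side of the union. A minimal sketch, assuming an active Spark session and the same temp views created inside get_data above:

    # cast NULL to STRING so the unioned 'd' column keeps the type from USRegion
    sql_union_null_df = spark.sql(
        "SELECT a, b, c, d FROM USRegion "
        "UNION ALL "
        "SELECT a, b, c, CAST(NULL AS STRING) AS d FROM EURegion")
    sql_union_null_df.show()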
