在Pyspark中使用udf时出现__getnewargs__错误

时间:2019-10-11 21:23:21

标签: python dataframe pyspark

有一个包含两列(db和tb)的datafarame:db代表数据库,而tb代表该数据库的tableName。

   +--------------------+--------------------+
   |            database|           tableName|
   +--------------------+--------------------+
   |aaaaaaaaaaaaaaaaa...|    tttttttttttttttt|
   |bbbbbbbbbbbbbbbbb...|    rrrrrrrrrrrrrrrr|
   |aaaaaaaaaaaaaaaaa...|  ssssssssssssssssss|

我在python中有以下方法:

 def _get_tb_db(db, tb):
      df = spark.sql("select * from {}.{}".format(db, tb))
      return df.dtypes

和这个udf:

 test = udf(lambda db, tb: _get_tb_db(db, tb), StringType())

运行此程序:

 df = df.withColumn("dtype", test(col("db"), col("tb")))

出现以下错误:

 pickle.PicklingError: Could not serialize object: Py4JError: An 
 error occurred while calling o58.__getnewargs__. Trace:
 py4j.Py4JException: Method __getnewargs__([]) does not exist

我发现了一些关于stackoverflow的讨论:Spark __getnewargs__ error 但我不确定如何解决此问题? 错误是因为我正在UDF内创建另一个数据框吗?

类似于链接中的解决方案,我尝试了以下方法:

       cols = copy.deepcopy(df.columns)
       df = df.withColumn("dtype", scanning(cols[0], cols[1]))

但仍然出现错误

有解决方案吗?

1 个答案:

答案 0 :(得分:1)

该错误表示您不能在UDF中使用 Spark数据帧。但是由于包含数据库和表名称的数据框很可能很小,因此只需执行Python for循环就足够了,下面是一些可能有助于获取数据的方法:

from pyspark.sql import Row

# assume dfs is the df containing database names and table names
dfs.printSchema()
root
 |-- database: string (nullable = true)
 |-- tableName: string (nullable = true)

方法1:使用df.dtypes

运行sql select * from database.tableName limit 1以生成df并返回其dtype,然后将其转换为StringType()。

data = []
DRow = Row('database', 'tableName', 'dtypes')
for row in dfs.collect():
  try:
    dtypes = spark.sql('select * from `{}`.`{}` limit 1'.format(row.database, row.tableName)).dtypes
    data.append(DRow(row.database, row.tableName, str(dtypes)))
  except Exception, e:
    print("ERROR from {}.{}: [{}]".format(row.database, row.tableName, e))
    pass

df_dtypes = spark.createDataFrame(data)
# DataFrame[database: string, tableName: string, dtypes: string]

注意:

  • 使用dtypes而不是str(dtypes)将得到以下模式,其中_1_2分别是col_namecol_dtype

    root
     |-- database: string (nullable = true)
     |-- tableName: string (nullable = true)
     |-- dtypes: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- _1: string (nullable = true)
     |    |    |-- _2: string (nullable = true)
    
  • 使用此方法,每个表将只有一行。对于接下来的两个方法,表的每个col_type将具有其自己的行。

方法2:使用describe

您还可以通过运行spark.sql("describe tableName")来检索此信息,通过该操作直接获取数据帧,然后使用reduce函数合并所有表的结果。

from functools import reduce

def get_df_dtypes(db, tb):
  try:
    return spark.sql('desc `{}`.`{}`'.format(db, tb)) \
                .selectExpr(
                      '"{}" as `database`'.format(db)
                    , '"{}" as `tableName`'.format(tb)
                    , 'col_name'
                    , 'data_type')
  except Exception, e:
    print("ERROR from {}.{}: [{}]".format(db, tb, e))
    pass

# an example table:
get_df_dtypes('default', 'tbl_df1').show()
+--------+---------+--------+--------------------+
|database|tableName|col_name|           data_type|
+--------+---------+--------+--------------------+
| default|  tbl_df1| array_b|array<struct<a:st...|
| default|  tbl_df1| array_d|       array<string>|
| default|  tbl_df1|struct_c|struct<a:double,b...|
+--------+---------+--------+--------------------+

# use reduce function to union all tables into one df
df_dtypes = reduce(lambda d1, d2: d1.union(d2), [ get_df_dtypes(row.database, row.tableName) for row in dfs.collect() ])

方法3:使用spark.catalog.listColumns()

使用spark.catalog.listColumns()创建一个collections.Column对象的列表,检索namedataType并合并数据。生成的数据帧在其自己的列上使用col_name和col_dtype进行了规范化(与使用 Method-2 相同)。

data = []
DRow = Row('database', 'tableName', 'col_name', 'col_dtype')
for row in dfs.select('database', 'tableName').collect():
  try:
    for col in spark.catalog.listColumns(row.tableName, row.database):
      data.append(DRow(row.database, row.tableName, col.name, col.dataType))
  except Exception, e:
    print("ERROR from {}.{}: [{}]".format(row.database, row.tableName, e))
    pass

df_dtypes = spark.createDataFrame(data)
# DataFrame[database: string, tableName: string, col_name: string, col_dtype: string]

A注意::在检索元数据时,不同的Spark发行版/版本可能与describe tbl_name和其他命令产生不同的结果,请确保在查询中使用正确的列名。