import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import *
import os
@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def split(df, validation_period):
    """Logic"""
    return df

def train_test_split(spark, data_frame, request_json_data):
    data_frame = spark.createDataFrame(data_frame)
    print(data_frame.schema)
    validation_period = request_json_data['validation_period']
    groupby_key = request_json_data['groupby_key']
    data_frame.groupby(groupby_key).apply(split, validation_period).show()
I am unable to call the split function; it gives the error: apply() takes 2 positional arguments but 3 were given. I want to pass validation_period as an argument to the split function.
Answer 0 (score: 0)
Short answer: you cannot pass extra arguments to a pandas grouped-map UDF, because it receives only a single pandas DataFrame as its argument.
Long answer: there are other ways to get validation_period into the function.
Use some form of closure:
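For reference, this is roughly what a grouped-map pandas UDF is expected to look like; a minimal sketch, assuming an illustrative output schema and column names:

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, StringType, LongType
import pandas as pd

# Illustrative output schema -- replace with the real schema of your data.
schema = StructType([
    StructField("id", StringType()),
    StructField("value", LongType()),
])

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def split(pdf):
    # Spark calls this with exactly one pandas DataFrame per group -- no extra arguments.
    return pdf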
def split_fabric(validation_period):
    @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
    def split(df):
        """Logic -- validation_period is captured from the enclosing scope"""
        return df
    return split
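The factory is then invoked at the call site, so the returned UDF already has validation_period bound (reusing data_frame and groupby_key from the question):

data_frame.groupby(groupby_key).apply(split_fabric(validation_period)).show()

Because the value is baked into the closure, apply() still receives only the single UDF it expects.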
Pass it as a column:
data_frame \
    .withColumn("validation_period", F.lit(validation_period)) \
    .groupby(groupby_key).apply(split).show()
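Inside the UDF, the constant can then be read back from that column of each group's pandas DataFrame; a minimal sketch, assuming the same "validation_period" column name added above:

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def split(df):
    # Every row in the group carries the same literal, so read it from the first row.
    validation_period = df["validation_period"].iloc[0]
    # ... actual split logic using validation_period goes here ...
    return df

Note that a grouped-map UDF must return a DataFrame matching the declared output schema, so either include the extra column in schema or drop it before returning.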