我正在使用{s {1}}启动的pyspark测试pandas_udf
的分组地图功能。
PYSPARK_PYTHON=python3 pyspark
这是代码:
Python version: 3.6.8
pyarrow: 0.13.0
pyspark: 2.4.1
根据PySpark Usage Guide for Pandas with Apache Arrow,预期结果如下。
import numpy as np
import pandas as pd
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
("id", "v"))
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
# pdf is a pandas.DataFrame
v = pdf.v
return pdf.assign(v=v - v.mean())
df.groupby("id").apply(subtract_mean).show()
但是我拥有的是
# +---+----+
# | id| v|
# +---+----+
# | 1|-0.5|
# | 1| 0.5|
# | 2|-3.0|
# | 2|-1.0|
# | 2| 4.0|
# +---+----+
但是,此代码可以在Spark 2.3.1中按预期工作。我在某个地方错了还是这是一个错误?