I am computing the mean and standard deviation of the nested products data in a PySpark DataFrame:
+----------+--------------------------------+
|product_PK|                        products|
+----------+--------------------------------+
|       686|          [[686,520.70],[645,2]]|
|       685|[[685,45.556],[678,23],[655,21]]|
|       693|                              []|
+----------+--------------------------------+
The problem is that I get None for both the mean and the standard deviation. The most likely cause is that the code does not handle the empty array []; such entries should be counted as 0. In addition, IntegerType is probably wrong here, since the scores can be fractional. How can I get the correct results instead of None?
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import explode, col, udf, mean as mean_, stddev as stddev_

df = sqlCtx.createDataFrame(
    [(686, [[686, 520.70], [645, 2]]),
     (685, [[685, 45.556], [678, 23], [655, 21]]),
     (693, [])],
    ["product_PK", "products"]
)

get_score = udf(lambda x: x[1], IntegerType())

df_stats = df.withColumn('exploded', explode(col('products'))) \
    .withColumn('score', get_score(col('exploded'))) \
    .select(
        mean_(col('score')).alias('mean'),
        stddev_(col('score')).alias('std')
    ) \
    .collect()

mean = df_stats[0]['mean']
std = df_stats[0]['std']

print([mean, std])
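
A quick diagnostic, separate from the computation itself, is to print the inferred schema and look at the extracted scores before aggregating, to see which values come back as null:

df.printSchema()

df.withColumn('exploded', explode(col('products'))) \
    .withColumn('score', get_score(col('exploded'))) \
    .select('product_PK', 'exploded', 'score') \
    .show(truncate=False)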
Answer (score: 2)
First, you don't need a UDF to pull an item out of the array. Second, just use na.fill to replace the NULL values with a number (zero in your case):
df.withColumn("exploded", explode(col("products"))) \
    .withColumn("score", col("exploded").getItem(1)) \
    .na.fill(0) \
    .select(
        mean_(col("score")).alias("mean"),
        stddev_(col("score")).alias("stddev")
    ) \
    .show()
+----+------------------+
|mean| stddev|
+----+------------------+
| 9.2|11.734564329364767|
+----+------------------+
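
Note that the 9.2 above counts the two decimal scores (520.70 and 45.556) as 0: most likely they came through as null under an integer element type and were then filled by na.fill(0), which is also the IntegerType concern raised in the question. If those decimals should contribute their real values, one option, sketched here on the assumption that the data is built in Python exactly as in the question, is to create the DataFrame with an explicit schema so the inner arrays hold doubles:

from pyspark.sql.types import StructType, StructField, LongType, ArrayType, DoubleType

# explicit schema: keep the nested scores as doubles instead of losing them to null
schema = StructType([
    StructField("product_PK", LongType(), True),
    StructField("products", ArrayType(ArrayType(DoubleType())), True),
])

df = sqlCtx.createDataFrame(
    [(686, [[686.0, 520.70], [645.0, 2.0]]),
     (685, [[685.0, 45.556], [678.0, 23.0], [655.0, 21.0]]),
     (693, [])],
    schema
)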
To get each value into its own variable:
row = df.withColumn("exploded", explode(col("products"))) \
    .withColumn("score", col("exploded").getItem(1)) \
    .na.fill(0) \
    .select(
        mean_(col("score")).alias("mean"),
        stddev_(col("score")).alias("stddev")
    ) \
    .first()

mean = row.mean
stddev = row.stddev
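
One more detail from the question: explode drops the row whose products list is empty (product 693), so it never contributes a 0 to the statistics. If that product should count as a 0 score, here is a sketch using explode_outer (available in Spark 2.2+, which emits a null row for an empty or null array) together with the same na.fill(0):

from pyspark.sql.functions import explode_outer

row = df.withColumn("exploded", explode_outer(col("products"))) \
    .withColumn("score", col("exploded").getItem(1)) \
    .na.fill(0) \
    .select(
        mean_(col("score")).alias("mean"),
        stddev_(col("score")).alias("stddev")
    ) \
    .first()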