PySpark DataFrame: take the mean of a list in a column and create a new column of 1s and 0s depending on a condition

Asked: 2017-09-08 04:31:54

Tags: pyspark spark-dataframe

I am trying to compute the mean of the list of costs in each row of a PySpark DataFrame column; values below the mean should get the value 1, and values above the mean 0.

This is the current dataframe
+----------+--------------------+--------------------+
|        id|  collect_list(p_id)|collect_list(cost)  |
+----------+--------------------+--------------------+
|         7|[10, 987, 872]      |[12.0, 124.6, 197.0]|
|         6|[11, 858, 299]      |[15.0, 167.16, 50.0]|
|        17|                 [2]|           [65.4785]|
|         1|[34359738369, 343...|[16.023384, 104.9...|
|         3|[17179869185, 0, ...|[48.3255, 132.025...|
+----------+--------------------+--------------------+


This is the desired output

+----------+--------------------+--------------------+-----------+
|        id|    p_id            |cost                | result    |
+----------+--------------------+--------------------+-----------+
|         7|10                  |12.0                |  1        |
|         7|987                 |124.6               |  0        |
|         7|872                 |197.0               |  0        |
|         6|11                  |15.0                |  1        |
|         6|858                 |167.16              |  0        |
|         6|299                 |50.0                |  1        |
|        17|2                   |65.4785             |  1        |
+----------+--------------------+--------------------+-----------+

Any help is much appreciated, thanks a lot.

2 answers:

Answer 0: (score: 0)

You can build the result list for each row, then zip the p_id, cost, and result lists together. After that, use explode on the zipped column.

import numpy as np
from pyspark.sql.functions import udf, explode, col
from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType, DoubleType

def zip_cols(pid_list, cost_list):
    # flag each cost with 1 if it is at or below the row's mean, else 0
    mean = np.mean(cost_list)
    res_list = [1 if cost <= mean else 0 for cost in cost_list]
    return [(x, y, z) for x, y, z in zip(pid_list, cost_list, res_list)]

udf_zip = udf(zip_cols, ArrayType(StructType([StructField("pid", IntegerType()),
                                              StructField("cost", DoubleType()),
                                              StructField("result", IntegerType())])))

df1 = (df.withColumn("temp", udf_zip("collect_list(p_id)", "collect_list(cost)"))
         .drop("collect_list(p_id)", "collect_list(cost)"))

df2 = (df1.withColumn("temp", explode(df1.temp))
          .select("id",
                  col("temp.pid").alias("pid"),
                  col("temp.cost").alias("cost"),
                  col("temp.result").alias("result")))
df2.show()

Output

+---+---+-------+------+
| id|pid|   cost|result|
+---+---+-------+------+
|  7| 10|   12.0|     1|
|  7|987|  124.6|     0|
|  7|872|  197.0|     0|
|  6| 11|   15.0|     1|
|  6|858| 167.16|     0|
|  6|299|   50.0|     1|
| 17|  2|65.4785|     1|
+---+---+-------+------+
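
As a side note, on Spark 2.4+ the same zip-and-explode can be done without a Python UDF via the built-in arrays_zip. A minimal sketch, assuming Spark 2.4+ and the input frame df from the question; the renamed columns p_ids/costs and the name df_flat are just for illustration:

from pyspark.sql.functions import arrays_zip, col, explode

# rename the aggregated columns so the zipped struct fields are easy to address
df_renamed = (df.withColumnRenamed("collect_list(p_id)", "p_ids")
                .withColumnRenamed("collect_list(cost)", "costs"))

# arrays_zip pairs the two arrays element-wise; explode yields one row per pair
df_flat = (df_renamed
           .withColumn("zipped", explode(arrays_zip("p_ids", "costs")))
           .select("id",
                   col("zipped.p_ids").alias("p_id"),
                   col("zipped.costs").alias("cost")))

The result flag can then be computed per id, for example as in the next answer.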

Answer 1: (score: 0)

Hope this helps!

from pyspark.sql.functions import col, mean

# sample data
df = sc.parallelize([(7, [10, 987, 872], [12.0, 124.6, 197.0]),
                     (6, [11, 858, 299], [15.0, 167.16, 50.0]),
                     (17, [2], [65.4785])]).toDF(["id", "collect_list(p_id)", "collect_list(cost)"])

# unpack the collect_list columns into one (id, p_id, cost) row per element
df = df.rdd.flatMap(lambda row: [(row[0], x, y) for x, y in zip(row[1], row[2])]).toDF(["id", "p_id", "cost"])

# join each row with its group's mean cost, then flag costs at or below the mean
df1 = df.\
    join(df.groupBy("id").agg(mean("cost").alias("mean_cost")), "id", 'left').\
    withColumn("result", (col("cost") <= col("mean_cost")).cast("int")).\
    drop("mean_cost")
df1.show()

The output is:

+---+----+-------+------+
| id|p_id|   cost|result|
+---+----+-------+------+
|  7|  10|   12.0|     1|
|  7| 987|  124.6|     0|
|  7| 872|  197.0|     0|
|  6|  11|   15.0|     1|
|  6| 858| 167.16|     0|
|  6| 299|   50.0|     1|
| 17|   2|65.4785|     1|
+---+----+-------+------+
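
As a side note, the groupBy plus join above can be replaced with a window aggregate over id, which avoids the self-join. A minimal sketch, assuming df is the exploded (id, p_id, cost) frame built above:

from pyspark.sql import Window
from pyspark.sql.functions import col, mean

# compute each id's mean cost with a window function instead of a join
w = Window.partitionBy("id")
df1 = df.withColumn("result", (col("cost") <= mean("cost").over(w)).cast("int"))
df1.show()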