根据来自另一个数据集的值将列添加到数据集

时间:2019-05-22 10:23:13

标签: java apache-spark apache-spark-sql

我有一个dsCustomer数据集,其中包含带有列的客户详细信息

|customerID|idpt | totalAmount|
|customer1 | H1  |    250     |
|customer2 | H2  |    175     |
|customer3 | H3  |    4000    |
|customer4 | H3  |    9000    |

我还有另一个数据集dsCategory,其中包含基于销售额的类别

|categoryID|idpt | borne_min|borne_max|
|A         |  H2 | 0        |1000     |
|B         |  H2 | 1000     |5000     |
|C         |  H2 | 5000     |7000     |
|D         |  H2 | 7000     |10000    |
|F         |  H3 | 0        |1000     |
|G         |  H3 | 1000     |5000     |
|H         |  H3 | 5000     |7000     |
|I         |  H3 | 7000     |1000000  |


我希望得到一个结果,该结果取了Customer的totalAmount并找到类别。

|customerID|idpt |totalAmount|category|
|customer1 | H1  |   250     | null   |
|customer2 | H2  |   175     | A      |
|customer3 | H3  |   4000    | G      |
|customer4 | H3  |   9000    | I      |
//udf 
public static Column getCategoryAmount(Dataset<Row> ds, Column amountColumn) {
        return ds.filter(amountColumn.geq(col("borne_min"))
                .and(amountColumn.lt(col("borne_max")))).first().getAs("categoryID");

    }

//code to add column to my dataset
dsCustomer.withColumn("category", getCategoryAmount(dsCategory , dsCustomer.col("totalAmount")));

如何将客户数据集中的列值传递给UDF函数

因为错误显示类别数据集中不包含totalAmount

问题:我应该如何使用dsCustomer中的每一行使用Map来检查它们在dsCategory中的值。

我尝试加入2个表,但是它起作用了,因为dsCustomer应该保持相同的记录,只是添加了从dsCategory中选择的计算列。

caused by: org.apache.spark.sql.AnalysisException: cannot resolve '`totalAmount`' given input columns: [categoryID,borne_min,borne_max];; 'Filter (('totalAmount>= borne_min#220) && ('totalAmount < borne_max#221))

1 个答案:

答案 0 :(得分:0)

您必须加入两个数据集。 withColumn仅允许修改相同的数据集。

更新

我没有时间详细解释我的意思。这就是我要解释的。您可以可以合并两个数据框。在您的情况下,您需要左连接来保留没有匹配类别的行。

from pyspark.sql import SparkSession


spark = SparkSession.builder.getOrCreate()

cust = [
    ('customer1', 'H1', 250), 
    ('customer2', 'H2', 175), 
    ('customer3', 'H3', 4000),
    ('customer4', 'H3', 9000)
]

cust_df = spark.createDataFrame(cust, ['customerID', 'idpt', 'totalAmount'])

cust_df.show()

cat = [
    ('A', 'H2', 0   , 1000),
    ('B', 'H2', 1000, 5000),
    ('C', 'H2', 5000, 7000),
    ('D', 'H2', 7000, 10000),
    ('F', 'H3', 0   , 1000),
    ('G', 'H3', 1000, 5000),
    ('H', 'H3', 5000, 7000),
    ('I', 'H3', 7000, 1000000)
]

cat_df = spark.createDataFrame(cat, ['categoryID', 'idpt', 'borne_min', 'borne_max'])

cat_df.show()

cust_df.join(cat_df, 
             (cust_df.idpt == cat_df.idpt) & 
             (cust_df.totalAmount >= cat_df.borne_min) & 
             (cust_df.totalAmount <= cat_df.borne_max)
            , how='left') \
.select(cust_df.customerID, cust_df.idpt, cust_df.totalAmount, cat_df.categoryID) \
.show()

输出

+----------+----+-----------+
|customerID|idpt|totalAmount|
+----------+----+-----------+
| customer1|  H1|        250|
| customer2|  H2|        175|
| customer3|  H3|       4000|
| customer4|  H3|       9000|
+----------+----+-----------+

+----------+----+---------+---------+
|categoryID|idpt|borne_min|borne_max|
+----------+----+---------+---------+
|         A|  H2|        0|     1000|
|         B|  H2|     1000|     5000|
|         C|  H2|     5000|     7000|
|         D|  H2|     7000|    10000|
|         F|  H3|        0|     1000|
|         G|  H3|     1000|     5000|
|         H|  H3|     5000|     7000|
|         I|  H3|     7000|  1000000|
+----------+----+---------+---------+

+----------+----+-----------+----------+
|customerID|idpt|totalAmount|categoryID|
+----------+----+-----------+----------+
| customer1|  H1|        250|      null|
| customer3|  H3|       4000|         G|
| customer4|  H3|       9000|         I|
| customer2|  H2|        175|         A|
+----------+----+-----------+----------+