Spark数据集使用agg()方法计数与条件匹配的行(在Java中)

时间:2019-11-20 11:19:15

标签: java apache-spark

我在Java中使用Apache Spark 2.3.1。我想通过使用agg()类的Dataset方法来计算匹配给定条件的数据集中的行数。

例如,我要计算以下数据集中label等于1.0的行数:

SparkSession spark = ...

List<Row> rows = new ArrayList<>();
rows.add(RowFactory.create(0, 0.0));
rows.add(RowFactory.create(1, 1.0));
rows.add(RowFactory.create(2, 1.0));

Dataset<Row> ds =
    spark.sqlContext().createDataFrame(rows,
        new StructType(new StructField[] {
            new StructField("id", DataTypes.LongType, false, Metadata.empty()),
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty())}));

我的猜测是使用以下代码:

ds.agg(functions.count(ds.col("label").equalTo(1.0))).show();

但是,显示错误的结果:

+--------------------+
|count((label = 1.0))|
+--------------------+
|                   3|
+--------------------+

正确的结果当然应该是2

agg()方法不应该这样工作吗?

3 个答案:

答案 0 :(得分:1)

agg()

count只计数不为null的值,因此可以这样做:

 import org.apache.spark.sql.functions._
 ds.agg(count(when('label.equalTo(1.0),1).otherwise(null))).show()

我在https://stackoverflow.com/a/1400115/9687910

处找到了此解决方案

答案 1 :(得分:0)

agg方法不应该这样工作。确实,这里您需要的是首先按照标签对数据进行分组,然后应用诸如 count max 以及更多。

df.filter("label".equalTo(1.0)).groupBy('label').agg(count("*").alias("cnt"))

它指的是以下documentation

答案 2 :(得分:0)

chlebek的答案是正确的。

使用Java语法:

obj_session = my_object.session

请注意,使用ds.agg(functions.count(functions.when(ds.col("label").equalTo(1.0), 0))).show(); 时,count函数的value自变量无关紧要(等效于SQL when)。

另一种实现此目的的方法是输出count(*)1所有结果:

sum

在这种情况下,ds.agg(functions.sum(functions.when(ds.col("label").equalTo(1.0), 1))).show(); 必须正好是value