我在Java中使用Apache Spark 2.3.1。我想通过使用agg()
类的Dataset
方法来计算匹配给定条件的数据集中的行数。
例如,我要计算以下数据集中label
等于1.0
的行数:
SparkSession spark = ...
List<Row> rows = new ArrayList<>();
rows.add(RowFactory.create(0, 0.0));
rows.add(RowFactory.create(1, 1.0));
rows.add(RowFactory.create(2, 1.0));
Dataset<Row> ds =
spark.sqlContext().createDataFrame(rows,
new StructType(new StructField[] {
new StructField("id", DataTypes.LongType, false, Metadata.empty()),
new StructField("label", DataTypes.DoubleType, false, Metadata.empty())}));
我的猜测是使用以下代码:
ds.agg(functions.count(ds.col("label").equalTo(1.0))).show();
但是,显示错误的结果:
+--------------------+
|count((label = 1.0))|
+--------------------+
| 3|
+--------------------+
正确的结果当然应该是2
。
agg()
方法不应该这样工作吗?
答案 0 :(得分:1)
agg()
中count只计数不为null的值,因此可以这样做:
import org.apache.spark.sql.functions._
ds.agg(count(when('label.equalTo(1.0),1).otherwise(null))).show()
处找到了此解决方案
答案 1 :(得分:0)
agg
方法不应该这样工作。确实,这里您需要的是首先按照标签对数据进行分组,然后应用诸如 count , max 以及更多。
df.filter("label".equalTo(1.0)).groupBy('label').agg(count("*").alias("cnt"))
它指的是以下documentation。
答案 2 :(得分:0)
chlebek的答案是正确的。
使用Java语法:
obj_session = my_object.session
请注意,使用ds.agg(functions.count(functions.when(ds.col("label").equalTo(1.0), 0))).show();
时,count
函数的value
自变量无关紧要(等效于SQL when
)。
另一种实现此目的的方法是输出count(*)
和1
所有结果:
sum
在这种情况下,ds.agg(functions.sum(functions.when(ds.col("label").equalTo(1.0), 1))).show();
必须正好是value
。