I am using spark-sql-2.4.1 with Java 8.
I have a scenario where I need to collect/reduce several computed datasets into a single common/result dataset, but I am running into an error.
Here is the code snippet:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.expr;
import static org.apache.spark.sql.functions.lit;

public class SampleApp { // wrapper class added so the snippet compiles; the name is arbitrary

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("SampleApp")
                .master("local[*]") // local master for a reproducible run
                .getOrCreate();

        List<String[]> stringAsList = new ArrayList<>();
        stringAsList.add(new String[] { "6",  "7",  "8",  "A" });
        stringAsList.add(new String[] { "82", "72", "58", "A" });
        stringAsList.add(new String[] { "63", "17", "18", "B" });
        stringAsList.add(new String[] { "16", "70", "81", "A" });
        stringAsList.add(new String[] { "69", "34", "8",  "B" });
        stringAsList.add(new String[] { "82", "72", "58", "A" });
        stringAsList.add(new String[] { "63", "17", "18", "B" });
        stringAsList.add(new String[] { "16", "70", "81", "A" });
        stringAsList.add(new String[] { "69", "34", "8",  "B" });

        JavaSparkContext sparkContext = new JavaSparkContext(spark.sparkContext());
        JavaRDD<Row> rowRDD = sparkContext.parallelize(stringAsList)
                .map((String[] row) -> RowFactory.create(row));

        StructType schema = DataTypes.createStructType(new StructField[] {
                DataTypes.createStructField("code1", DataTypes.StringType, false),
                DataTypes.createStructField("code2", DataTypes.StringType, false),
                DataTypes.createStructField("code3", DataTypes.StringType, false),
                DataTypes.createStructField("class", DataTypes.StringType, false)
        });

        Dataset<Row> dataDf = spark.sqlContext().createDataFrame(rowRDD, schema).toDF();

        Dataset<Row> data = dataDf
                .withColumn("code1", col("code1").cast(DataTypes.IntegerType))
                .withColumn("code2", col("code2").cast(DataTypes.IntegerType))
                .withColumn("code3", col("code3").cast(DataTypes.IntegerType))
                .withColumn("class", col("class").cast(DataTypes.StringType));
        // data.show();

        List<String> interestedCols = Arrays.asList("code1", "code3");

        // reduce the per-column aggregations into one final result dataset
        Dataset<Row> resultDs = interestedCols.stream()
                .reduce(data, SampleApp::dataAccumlator, SampleApp::datasetCombiner);
        resultDs.show();
        System.out.println("Done");
    }

    // aggregates one column: mean plus the 0.5/0.4/0.1 percentiles per class
    static Dataset<Row> dataAccumlator(Dataset<Row> df, String coll) {
        Dataset<Row> dd = df.groupBy("class").agg(
                avg(coll).alias("mean"),
                expr("percentile(" + coll + ", array(0.5, 0.4, 0.1))").alias("percentiles"));
        return dd
                .withColumn("per50", col("percentiles").getItem(0))
                .withColumn("coll", lit(coll));
    }

    // unions the partial results of two columns
    static Dataset<Row> datasetCombiner(Dataset<Row> left, Dataset<Row> right) {
        return left.union(right);
    }
}
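For context, on a sequential stream Stream.reduce folds the accumulator left to right; the combiner only comes into play for parallel streams. A minimal sketch (reusing the methods above, with hypothetical intermediate variable names) of what the reduce call effectively evaluates to:

// identity = data; the accumulator is chained once per interested column, in order
Dataset<Row> step1 = dataAccumlator(data, "code1");    // columns now: class, mean, percentiles, per50, coll
Dataset<Row> resultDs = dataAccumlator(step1, "code3"); // "code3" must be resolved against step1's columns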
Given input data:
+-----+-----+-----+-----+
|code1|code2|code3|class|
+-----+-----+-----+-----+
| 6| 7| 8| A|
| 82| 72| 58| A|
| 63| 17| 18| B|
| 16| 70| 81| A|
| 69| 34| 8| B|
| 82| 72| 58| A|
| 63| 17| 18| B|
| 16| 70| 81| A|
| 69| 34| 8| B|
+-----+-----+-----+-----+
Expected output:
+-----+----+------------------+-----+-----+
|class|mean| percentiles|per50| coll|
+-----+----+------------------+-----+-----+
| B|66.0|[66.0, 64.2, 63.0]| 66.0|code1|
| A|40.4|[16.0, 16.0, 10.0]| 16.0|code1|
| B|13.0| [13.0, 10.0, 8.0]| 13.0|code3|
| A|57.2|[58.0, 58.0, 28.0]| 58.0|code3|
+-----+----+------------------+-----+-----+
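For reference, Spark's percentile computes an exact percentile with linear interpolation at position p * (n - 1) over the sorted values. A quick sanity check of the 64.2 in the first row (class B, code1, sorted values 63, 63, 69, 69), assuming the spark session from the snippet above:

// position = 0.4 * (4 - 1) = 1.2  ->  63 + 0.2 * (69 - 63) = 64.2
spark.sql("SELECT percentile(v, 0.4) FROM VALUES (63), (63), (69), (69) AS t(v)").show();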
Error encountered:
org.apache.spark.sql.AnalysisException: cannot resolve '`code3`' given input columns: [class, mean, count, percentiles];;
How can I resolve this error and get the desired output? What am I doing wrong here, and what is the workaround?