使用Java将索引列添加到Apache Spark Dataset <row>

时间:2019-05-16 07:42:23

标签: java apache-spark

以下问题为scala和pyspark提供了解决方案,此问题中提供的解决方案不适用于连续的索引值。

Spark Dataframe :How to add a index Column : Aka Distributed Data Index

我在Apache-spark中有一个现有的数据集,我想根据索引从中选择一些行。我打算添加一个索引列,其中包含从1开始的唯一值,并且基于该列的值,我将获取行。 我发现以下添加使用order by的索引的方法:

df.withColumn("index", functions.row_number().over(Window.orderBy("a column")));

我不想使用order by。我需要按照它们在数据集中出现的顺序排列索引。有帮助吗?

2 个答案:

答案 0 :(得分:0)

根据我的收集,您正在尝试向数据帧添加索引(具有连续值)。不幸的是,Spark中没有内置函数可以做到这一点。您只能使用df.withColumn(“ index”,monotonicallyIncreasingId)添加一个递增的索引(但不一定具有连续值)。

尽管如此,RDD API中仍然存在一个zipWithIndex函数,它可以完全满足您的需求。因此,我们可以定义一个函数,将数据帧转换为RDD,添加索引并将其转换回数据帧。

我不是Java火花专家(scala更为紧凑),因此可能会做得更好。这就是我的做法。

public static Dataset<Row> zipWithIndex(Dataset<Row> df, String name) {
    JavaRDD<Row> rdd = df.javaRDD().zipWithIndex().map(t -> {
        Row r = t._1;
        Long index = t._2 + 1;
        ArrayList<Object> list = new ArrayList<>();
        r.toSeq().iterator().foreach(x -> list.add(x));
        list.add(index);
        return RowFactory.create(list);
    });
    StructType newSchema = df.schema()
            .add(new StructField(name, DataTypes.LongType, true, null));
    return df.sparkSession().createDataFrame(rdd, newSchema);
}

这是您将如何使用它。请注意,与我们的方法相反,内置的spark函数的作用。

Dataset<Row> df = spark.range(5)
    .withColumn("index1", functions.monotonicallyIncreasingId());
Dataset<Row> result = zipWithIndex(df, "good_index");
// df
+---+-----------+
| id|     index1|
+---+-----------+
|  0|          0|
|  1| 8589934592|
|  2|17179869184|
|  3|25769803776|
|  4|25769803777|
+---+-----------+

// result
+---+-----------+----------+
| id|     index1|good_index|
+---+-----------+----------+
|  0|          0|         1|
|  1| 8589934592|         2|
|  2|17179869184|         3|
|  3|25769803776|         4|
|  4|25769803777|         5|
+---+-----------+----------+

答案 1 :(得分:0)

上面的答案对我有些调整。以下是功能正常的Intellij Scratch文件。我正在使用Spark 2.3.0:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;

class Scratch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                    .builder()
                    .appName("_LOCAL")
                    .master("local")
                    .getOrCreate();
        Dataset<Row> df = spark.range(5)
                .withColumn("index1", functions.monotonicallyIncreasingId());
        Dataset<Row> result = zipWithIndex(df, "good_index");
        result.show();
    }
    public static Dataset<Row> zipWithIndex(Dataset<Row> df, String name) {
        JavaRDD<Row> rdd = df.javaRDD().zipWithIndex().map(t -> {
            Row r = t._1;
            Long index = t._2 + 1;
            ArrayList<Object> list = new ArrayList<>();
            scala.collection.Iterator<Object> iterator = r.toSeq().iterator();
            while(iterator.hasNext()) {
                Object value = iterator.next();
                assert value != null;
                list.add(value);
            }
            list.add(index);
            return RowFactory.create(list.toArray());
        });
        StructType newSchema = df.schema()
                .add(new StructField(name, DataTypes.LongType, true, Metadata.empty()));
        return df.sparkSession().createDataFrame(rdd, newSchema);
    }
}

输出:

+---+------+----------+
| id|index1|good_index|
+---+------+----------+
|  0|     0|         1|
|  1|     1|         2|
|  2|     2|         3|
|  3|     3|         4|
|  4|     4|         5|
+---+------+----------+