Multiple identical stages in Spark

Date: 2018-07-11 11:26:04

Tags: apache-spark directed-acyclic-graphs

I have implemented Gaussian elimination on Spark. The code is as follows:

import java.util.ArrayList;
import java.util.Iterator;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

import scala.Tuple2;

public class GaussianElimination {

public static void main(String args[]) {

    SparkConf conf = new SparkConf();
    JavaSparkContext sc = new JavaSparkContext(conf);
    JavaRDD<String> lines = sc.textFile("E:\\Gaussian Elimination\\3x3.csv");
    JavaRDD<Vector> rows = lines.map(new Function<String, Vector>() {

        @Override
        public Vector call(String line) throws Exception {
            double[] arr = new double[3];
            String[] strArr = line.split(",");

            for (int i = 0; i < strArr.length; i++) {
                arr[i] = Double.parseDouble(strArr[i]);
            }

            Vector vec = Vectors.dense(arr);
            return vec;
        }

    });

    JavaPairRDD<Long, Vector> row = rows.zipWithIndex()
            .mapToPair(new PairFunction<Tuple2<Vector, Long>, Long, Vector>() {

                @Override
                public Tuple2<Long, Vector> call(Tuple2<Vector, Long> tuple) throws Exception {
                    return new Tuple2<Long, Vector>(tuple._2, tuple._1);
                }
            });

    JavaPairRDD<Long, Vector> divide = null;
    JavaPairRDD<Long, Vector> result = null;
    for (long i = 0; i < 2; i++) {
        Broadcast<Long> index = sc.broadcast(i);
        divide = row.mapToPair(new PairFunction<Tuple2<Long, Vector>, Long, Vector>() {

            @Override
            public Tuple2<Long, Vector> call(Tuple2<Long, Vector> tuple) throws Exception {
                long indx = index.getValue();
                Vector vec = tuple._2;
                if (tuple._1 == indx) {
                    double[] arr = vec.toArray();
                    double pivot = arr[(int) indx];
                    for (int j = 0; j < 3; j++) {
                        arr[j] = arr[j] / pivot;
                    }
                    vec = Vectors.dense(arr);

                }
                return new Tuple2<Long, Vector>(tuple._1, vec);
            }
        }); 

        //System.out.println("/////////////////////"+divide.collect()+"///////////////////////");

        JavaPairRDD<Long, Vector> replicate = divide
                .flatMapToPair(new PairFlatMapFunction<Tuple2<Long, Vector>, Long, Vector>() {

                    @Override
                    public Iterator<Tuple2<Long, Vector>> call(Tuple2<Long, Vector> tuple) throws Exception {
                        long indx = index.getValue();
                        Vector vec = tuple._2;
                        ArrayList<Tuple2<Long, Vector>> list = new ArrayList<Tuple2<Long, Vector>>();

                        if (tuple._1 == indx) {
                            for (int i = 0; i < 3; i++) {
                                list.add(new Tuple2<Long, Vector>((long) i, vec));
                            }
                        }
                        return list.iterator();
                    }

                });

        //System.out.println("/////////////////////"+replicate.collect()+"///////////////////////");

        JavaPairRDD<Long, Tuple2<Vector, Vector>> change = divide.join(replicate);

        //System.out.println("/////////////////////"+change.collect()+"///////////////////////");

        result = change
                .mapToPair(new PairFunction<Tuple2<Long, Tuple2<Vector, Vector>>, Long, Vector>() {

                    @Override
                    public Tuple2<Long, Vector> call(Tuple2<Long, Tuple2<Vector, Vector>> tuple) throws Exception {
                        long indx = index.getValue();

                        Vector vec1 = tuple._2._1;
                        Vector vec2 = tuple._2._2;
                        double[] arr1 = vec1.toArray();
                        double[] arr2 = vec2.toArray();
                        double pivot = arr1[(int) indx];
                        if (tuple._1 != indx) {
                            for (int i = 0; i < 3; i++) {
                                arr1[i] = arr1[i] - arr2[i] * pivot;                                    
                            }
                        }
                        Vector result = Vectors.dense(arr1);

                        return new Tuple2<Long, Vector>(tuple._1, result);
                    }

                });

        //System.out.println("////////////"+result.collect()+"////////////////////////");
        row = result;
    }

    row.count();

    sc.close();
}

}
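For reference, the per-iteration logic the Spark job distributes (divide the pivot row by its pivot, then subtract a multiple of it from every other row) can be sketched sequentially in plain Java. This is a local sketch for a small augmented matrix, not part of the original post; the class and method names are illustrative only.

```java
import java.util.Arrays;

public class GaussJordanLocal {

    // One elimination iteration for pivot index i:
    // normalize row i, then eliminate column i from every other row.
    static void eliminate(double[][] m, int i) {
        int cols = m[i].length;
        double pivot = m[i][i];
        for (int j = 0; j < cols; j++) {
            m[i][j] /= pivot;                       // the "divide" step
        }
        for (int r = 0; r < m.length; r++) {        // the "result" step
            if (r == i) continue;
            double factor = m[r][i];
            for (int j = 0; j < cols; j++) {
                m[r][j] -= m[i][j] * factor;
            }
        }
    }

    public static void main(String[] args) {
        // Augmented 2x3 system: 2x + y = 4, x + y = 3  ->  x = 1, y = 2
        double[][] m = { {2, 1, 4}, {1, 1, 3} };
        for (int i = 0; i < m.length; i++) {
            eliminate(m, i);
        }
        System.out.println(Arrays.deepToString(m));
        // prints [[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]]
    }
}
```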

The DAG for the 2 iterations is as follows:

[screenshot: Spark DAG for the two iterations]

My questions are as follows:

1) Why is the redundant Stage 2 created? As in Stage 3, the output of the map and the flatMap could be combined within Stage 1 itself.

2) Why is Stage 4 created? It also appears redundant.

3) How can a stage (Stage 3) begin with a join transformation? And

4) What is the way to get rid of these extra stages?
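One common way to avoid the join (and the shuffle stages it forces) is to fetch the pivot row to the driver each iteration, broadcast it, and do both the divide and the eliminate in a single mapToPair. This is an assumption about what would work here, not from the post; the per-row update such a map would apply can be sketched locally (the `updateRow` helper is hypothetical):

```java
import java.util.Arrays;

public class PivotBroadcastStep {

    // The row update a single mapToPair could apply when the already-normalized
    // pivot row arrives via a broadcast variable: the pivot row itself is
    // normalized; every other row has a multiple of the pivot row subtracted.
    static double[] updateRow(long rowIdx, double[] row, long pivotIdx, double[] pivotRow) {
        int n = row.length;
        double[] out = new double[n];
        if (rowIdx == pivotIdx) {
            double pivot = row[(int) pivotIdx];
            for (int j = 0; j < n; j++) {
                out[j] = row[j] / pivot;                    // divide step
            }
        } else {
            double factor = row[(int) pivotIdx];
            for (int j = 0; j < n; j++) {
                out[j] = row[j] - pivotRow[j] * factor;     // eliminate step
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] pivotRow = {1, 0.5, 2};  // row 0 of {2, 1, 4} after normalization
        // Eliminate column 0 from row 1 of the system.
        System.out.println(Arrays.toString(updateRow(1, new double[]{1, 1, 3}, 0, pivotRow)));
        // prints [0.0, 0.5, 1.0]
    }
}
```

With this shape, each loop iteration is one narrow map over the RDD, so the map/flatMap/join chain (and the duplicated stages it produces) disappears.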

Thanks.

0 Answers:

There are no answers yet.