I have implemented Gaussian elimination on Spark. The code is as follows:
import java.util.ArrayList;
import java.util.Iterator;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

import scala.Tuple2;

public class GaussianElimination {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Parse each CSV line into a dense vector of 3 doubles.
        JavaRDD<String> lines = sc.textFile("E:\\Gaussian Elimination\\3x3.csv");
        JavaRDD<Vector> rows = lines.map(new Function<String, Vector>() {
            @Override
            public Vector call(String line) throws Exception {
                String[] strArr = line.split(",");
                double[] arr = new double[3];
                for (int i = 0; i < strArr.length; i++) {
                    arr[i] = Double.parseDouble(strArr[i]);
                }
                return Vectors.dense(arr);
            }
        });

        // Key each row by its row index.
        JavaPairRDD<Long, Vector> row = rows.zipWithIndex()
                .mapToPair(new PairFunction<Tuple2<Vector, Long>, Long, Vector>() {
                    @Override
                    public Tuple2<Long, Vector> call(Tuple2<Vector, Long> tuple) throws Exception {
                        return new Tuple2<Long, Vector>(tuple._2, tuple._1);
                    }
                });

        JavaPairRDD<Long, Vector> divide = null;
        JavaPairRDD<Long, Vector> result = null;

        for (long i = 0; i < 2; i++) {
            final Broadcast<Long> index = sc.broadcast(i);

            // Normalize the pivot row (row i) by its pivot element; other rows pass through unchanged.
            divide = row.mapToPair(new PairFunction<Tuple2<Long, Vector>, Long, Vector>() {
                @Override
                public Tuple2<Long, Vector> call(Tuple2<Long, Vector> tuple) throws Exception {
                    long indx = index.getValue();
                    Vector vec = tuple._2;
                    if (tuple._1 == indx) {
                        double[] arr = vec.toArray();
                        double pivot = arr[(int) indx];
                        for (int j = 0; j < 3; j++) {
                            arr[j] = arr[j] / pivot;
                        }
                        vec = Vectors.dense(arr);
                    }
                    return new Tuple2<Long, Vector>(tuple._1, vec);
                }
            });
            //System.out.println("/////////////////////" + divide.collect() + "///////////////////////");

            // Replicate the normalized pivot row once per row index (0, 1, 2).
            JavaPairRDD<Long, Vector> replicate = divide
                    .flatMapToPair(new PairFlatMapFunction<Tuple2<Long, Vector>, Long, Vector>() {
                        @Override
                        public Iterator<Tuple2<Long, Vector>> call(Tuple2<Long, Vector> tuple) throws Exception {
                            long indx = index.getValue();
                            Vector vec = tuple._2;
                            ArrayList<Tuple2<Long, Vector>> list = new ArrayList<Tuple2<Long, Vector>>();
                            if (tuple._1 == indx) {
                                for (int k = 0; k < 3; k++) {
                                    list.add(new Tuple2<Long, Vector>((long) k, vec));
                                }
                            }
                            return list.iterator();
                        }
                    });
            //System.out.println("/////////////////////" + replicate.collect() + "///////////////////////");

            // Pair every row with the replicated pivot row.
            JavaPairRDD<Long, Tuple2<Vector, Vector>> change = divide.join(replicate);
            //System.out.println("/////////////////////" + change.collect() + "///////////////////////");

            // Eliminate the pivot column from every row except the pivot row itself.
            result = change
                    .mapToPair(new PairFunction<Tuple2<Long, Tuple2<Vector, Vector>>, Long, Vector>() {
                        @Override
                        public Tuple2<Long, Vector> call(Tuple2<Long, Tuple2<Vector, Vector>> tuple) throws Exception {
                            long indx = index.getValue();
                            Vector vec1 = tuple._2._1; // the row being updated
                            Vector vec2 = tuple._2._2; // the normalized pivot row
                            double[] arr1 = vec1.toArray();
                            double[] arr2 = vec2.toArray();
                            double pivot = arr1[(int) indx];
                            if (tuple._1 != indx) {
                                for (int k = 0; k < 3; k++) {
                                    arr1[k] = arr1[k] - arr2[k] * pivot;
                                }
                            }
                            return new Tuple2<Long, Vector>(tuple._1, Vectors.dense(arr1));
                        }
                    });
            //System.out.println("////////////" + result.collect() + "////////////////////////");

            row = result;
        }

        row.count();
        sc.close();
    }
}
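For reference, the following is a minimal local sketch (plain Java, no Spark) of the arithmetic the two loop iterations perform. The 3x3 input values are made up purely to make the divide/eliminate steps concrete; the class is not part of the Spark program above.

public class LocalGaussJordanSketch {
    public static void main(String[] args) {
        double[][] m = {
                { 2,  4, -2},
                { 4,  9, -3},
                {-2, -3,  7}
        };
        for (int i = 0; i < 2; i++) {                // the same two pivot steps as the RDD loop
            double pivot = m[i][i];
            for (int j = 0; j < 3; j++) {            // "divide": normalize the pivot row
                m[i][j] /= pivot;
            }
            for (int r = 0; r < 3; r++) {            // "result": eliminate the pivot column from the other rows
                if (r == i) {
                    continue;
                }
                double factor = m[r][i];
                for (int j = 0; j < 3; j++) {
                    m[r][j] -= factor * m[i][j];
                }
            }
        }
        // Prints:
        // [1.0, 0.0, -3.0]
        // [0.0, 1.0, 1.0]
        // [0.0, 0.0, 4.0]
        for (double[] r : m) {
            System.out.println(java.util.Arrays.toString(r));
        }
    }
}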
The DAG for the 2 iterations is as follows.
My questions are:
1) Why is the redundant Stage 2 created? Just as in Stage 3, the outputs of the map and the flatMap could have been combined (pipelined) from the result of Stage 1 itself.
2) Why is Stage 4 created? It also looks redundant.
3) How can a stage (Stage 3) begin with a join transformation? And
4) What is the way to get rid of these extra stages? (See the untested sketch after this list.)
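To make question 4 concrete: would a restructuring along the following lines remove the extra stages? The idea is to collect the pivot row on the driver (here via lookup(), which does trigger a job per iteration) and broadcast its normalized form, so that each iteration becomes a single mapToPair with no join. This is an untested sketch that reuses sc and row from the code above; idx and normPivot are just illustrative names, and I have not checked what DAG it actually produces.

for (long i = 0; i < 2; i++) {
    final long idx = i;
    // Fetch the pivot row on the driver instead of replicating and joining it.
    Vector pivotVec = row.lookup(idx).get(0);
    double[] p = pivotVec.toArray().clone();
    double pivot = p[(int) idx];
    for (int j = 0; j < p.length; j++) {
        p[j] = p[j] / pivot;                      // normalize the pivot row on the driver
    }
    final Broadcast<double[]> normPivot = sc.broadcast(p);

    row = row.mapToPair(new PairFunction<Tuple2<Long, Vector>, Long, Vector>() {
        @Override
        public Tuple2<Long, Vector> call(Tuple2<Long, Vector> tuple) throws Exception {
            double[] piv = normPivot.getValue();
            double[] arr = tuple._2.toArray().clone();
            if (tuple._1 == idx) {
                arr = piv.clone();                // pivot row: replaced by its normalized form
            } else {
                double factor = arr[(int) idx];
                for (int j = 0; j < arr.length; j++) {
                    arr[j] = arr[j] - factor * piv[j];   // eliminate the pivot-column entry
                }
            }
            return new Tuple2<Long, Vector>(tuple._1, Vectors.dense(arr));
        }
    });
}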
Thanks.