Spark code gives wrong matrix multiplication result

Time: 2015-05-21 04:12:10

Tags: apache-spark matrix-multiplication

I have two matrices, {1,2,3; 4,5,6; 7,8,9} and {1,4; 2,5; 3,6}. The following code multiplies them in Apache Spark, but it gives the wrong output {15.0,29.0; 36.0,71.0; 57.0,113.0}. What am I doing wrong?

    JavaRDD<String> lines = ctx
            .textFile(
                    "/home/hduser/Desktop/interpolation/Kriging/MatrixMultiplication/MatrixA.csv")
            .cache();

    JavaRDD<String> lines1 = ctx
            .textFile(
                    "/home/hduser/Desktop/interpolation/Kriging/MatrixMultiplication/MatrixB.csv")
            .cache();

    JavaRDD<Vector> rows = lines.map(new Function<String, Vector>() {

        @Override
        public Vector call(String line) throws Exception {

            String[] lineSplit = line.split(",");
            double[] arr = new double[lineSplit.length];
            for (int i = 0; i < lineSplit.length; i++) {
                arr[i] = Double.parseDouble(lineSplit[i]);
            }
            Vector dv = Vectors.dense(arr);
            return dv;
        }

    });

    //rows.saveAsTextFile("/home/hduser/Desktop/interpolation/Kriging/MatrixMultiplication/MatrixA_output");

    RowMatrix A = new RowMatrix(rows.rdd());



    JavaRDD<Vector> rows1 = lines1.map(new Function<String, Vector>() {

        @Override
        public Vector call(String line) throws Exception {

            String[] lineSplit = line.split(",");
            double[] arr = new double[lineSplit.length];
            for (int i = 0; i < lineSplit.length; i++) {
                arr[i] = Double.parseDouble(lineSplit[i]);
            }
            Vector dv = Vectors.dense(arr);
            return dv;
        }

    });

    // collect() returns the same list as the deprecated toArray()
    List<Vector> arrList = rows1.collect();


    double[] arr1 = new double[(int) rows1.count() * arrList.get(0).size()];
    int k=0;
    for (int i = 0; i < arrList.size(); i++) {
        for (int j = 0; j < arrList.get(i).size(); j++) {
            arr1[k] = arrList.get(i).apply(j);
            //System.out.println(arr1[k]);
            k++;
        }
    }
    Matrix B = Matrices.dense((int) rows1.count(), arrList.get(0)
            .size(), arr1);

    RowMatrix C = A.multiply(B);


    RDD<Vector> rows2 = C.rows();
    rows2.saveAsTextFile("/home/hduser/Desktop/interpolation/Kriging/MatrixMultiplication/Result");

Thanks in advance...

1 Answer:

Answer 0: (score: 1)

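Matrices.dense constructs a column-major matrix (see the API doc), but you fill the values array by walking over the rows, i.e. in row-major order. The array is therefore read back in the wrong order.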
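A minimal sketch of the fix, assuming the rows of B are already available as a plain `double[][]`. The class name `ColumnMajorDemo` and the helper `toColumnMajor` are made up for illustration and are not part of Spark. The only change from the question's loop is the index at which each element is stored:

```java
import java.util.Arrays;

public class ColumnMajorDemo {

    // Hypothetical helper (not a Spark API): flattens a matrix given row by row
    // into a column-major array, the layout Matrices.dense expects.
    public static double[] toColumnMajor(double[][] rows) {
        int numRows = rows.length;
        int numCols = rows[0].length;
        double[] values = new double[numRows * numCols];
        for (int i = 0; i < numRows; i++) {
            for (int j = 0; j < numCols; j++) {
                // In column-major storage, element (i, j) lives at j * numRows + i.
                values[j * numRows + i] = rows[i][j];
            }
        }
        return values;
    }

    public static void main(String[] args) {
        double[][] b = {{1, 4}, {2, 5}, {3, 6}};
        System.out.println(Arrays.toString(toColumnMajor(b)));
        // prints [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    }
}
```

In the question's own loop this would correspond to writing `arr1[j * arrList.size() + i] = arrList.get(i).apply(j);` instead of using the running counter `k`.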

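I can't see your CSV files, but the output you report is exactly what the ordering issue alone would produce.

For the product to come out as {15,29; 36,71; 57,113}, B must have been read as [1 5; 4 3; 2 6], which means the array passed to Matrices.dense was {1,4,2,5,3,6}. That is what you get when MatrixB.csv contains:

1,4
2,5
3,6

and is flattened row by row, as your loop does.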
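The effect of the column-major layout can be checked with plain arithmetic. The sketch below uses plain Java with no Spark involved, and the class and method names are invented for illustration. It builds B from a flat array the way a column-major constructor would and multiplies A by it, once with {1,4; 2,5; 3,6} flattened row by row and once flattened column by column:

```java
import java.util.Arrays;

public class LayoutCheck {

    // Interprets a flat array as a numRows x numCols matrix stored column by column.
    public static double[][] fromColumnMajor(int numRows, int numCols, double[] values) {
        double[][] m = new double[numRows][numCols];
        for (int j = 0; j < numCols; j++) {
            for (int i = 0; i < numRows; i++) {
                m[i][j] = values[j * numRows + i];
            }
        }
        return m;
    }

    // Plain triple-loop matrix product a * b.
    public static double[][] multiply(double[][] a, double[][] b) {
        double[][] c = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b[0].length; j++) {
                for (int k = 0; k < b.length; k++) {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
        double[] rowByRow = {1, 4, 2, 5, 3, 6}; // what the question's loop produces
        double[] colByCol = {1, 2, 3, 4, 5, 6}; // what a column-major layout needs

        System.out.println(Arrays.deepToString(multiply(a, fromColumnMajor(3, 2, rowByRow))));
        // prints [[15.0, 29.0], [36.0, 71.0], [57.0, 113.0]]

        System.out.println(Arrays.deepToString(multiply(a, fromColumnMajor(3, 2, colByCol))));
        // prints [[14.0, 32.0], [32.0, 77.0], [50.0, 122.0]]
    }
}
```

The first result matches the output reported in the question. The second is the correct product of the two matrices.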