使用Strassen算法的矩阵乘法C代码有什么问题?

时间:2018-06-13 18:18:11

标签: c algorithm matrix-multiplication

// Probably some mistake in "conquer" part of the recursive function "multiplySquareMatrices".

#include<stdio.h>

void showSquareMatrix(int n, int[][n]);
void copyMatrix_Complete2Part(int n, int[][n], int, int, int, int, int m, 
int[][m]);
void subSquareMatrices(int n, int[][n], int[][n], int[][n]);
void addSquareMatrices(int n, int[][n], int[][n], int[][n]);
void copyMatrix_Part2Complete(int n, int[][n], int m, int[][m], int, int, int, int);
void multiplySquareMatrices(int n, int[][n], int[][n], int[][n]);
void acceptSquareMatrix(int n, int[][n]);

int main()
{
    int n = 2;
    int matrixA[n][n];
    int matrixB[n][n];
    int matrixAB[n][n];

    acceptSquareMatrix(n, matrixA);
    acceptSquareMatrix(n, matrixB);

    multiplySquareMatrices(n, matrixAB, matrixA, matrixB);
    showSquareMatrix(n, matrixAB);
}

void acceptSquareMatrix(int n, int matrix[][n])     // matrix[][columns] should not be changed.
{
    int i;
    int j;
    for(i = 0; i < n; i++)
    {
        for(j = 0; j < n; j++)
        {
            printf("\n[%d][%d]: ", i, j);
            scanf("%d", &matrix[i][j]);
        }
    }
}

void multiplySquareMatrices(int n, int matrixAB[][n], int matrixA[][n], int matrixB[][n])
{
    if(n == 1)
    {
        matrixAB[0][0] = matrixA[0][0] * matrixB[0][0];
    }
    else
    {
        //divide
        int matrix_a[n / 2][n / 2];
        int matrix_b[n / 2][n / 2];
        int matrix_c[n / 2][n / 2];
        int matrix_d[n / 2][n / 2];

        int matrix_e[n / 2][n / 2];
        int matrix_f[n / 2][n / 2];
        int matrix_g[n / 2][n / 2];
        int matrix_h[n / 2][n / 2];

        int rowStart_a = 0;     int rowEnd_a = (n/2) - 1;     int columnStart_a = 0;        int columnEnd_a = (n/2) - 1;
        int rowStart_b = 0;     int rowEnd_b = (n/2) - 1;     int columnStart_b = n/2;      int columnEnd_b = n - 1;
        int rowStart_c = n/2;   int rowEnd_c = n - 1;         int columnStart_c = 0;        int columnEnd_c = (n/2) - 1;
        int rowStart_d = n/2;   int rowEnd_d = n - 1;         int columnStart_d = n/2;      int columnEnd_d = n - 1;

        int rowStart_e = 0;     int rowEnd_e = (n/2) - 1;     int columnStart_e = 0;        int columnEnd_e = (n/2) - 1;
        int rowStart_f = 0;     int rowEnd_f = (n/2) - 1;     int columnStart_f = n/2;      int columnEnd_f = n - 1;
        int rowStart_g = n/2;   int rowEnd_g = n - 1;         int columnStart_g = 0;        int columnEnd_g = (n/2) - 1;
        int rowStart_h = n/2;   int rowEnd_h = n - 1;         int columnStart_h = n/2;      int columnEnd_h = n - 1;

        copyMatrix_Part2Complete(n / 2, matrix_a, n, matrixA, rowStart_a, rowEnd_a, columnStart_a, columnEnd_a);
        copyMatrix_Part2Complete(n / 2, matrix_b, n, matrixA, rowStart_b, rowEnd_b, columnStart_b, columnEnd_b);
        copyMatrix_Part2Complete(n / 2, matrix_c, n, matrixA, rowStart_c, rowEnd_c, columnStart_c, columnEnd_c);
        copyMatrix_Part2Complete(n / 2, matrix_d, n, matrixA, rowStart_d, rowEnd_d, columnStart_d, columnEnd_d);

        copyMatrix_Part2Complete(n / 2, matrix_e, n, matrixB, rowStart_e, rowEnd_e, columnStart_e, columnEnd_e);
        copyMatrix_Part2Complete(n / 2, matrix_f, n, matrixB, rowStart_f, rowEnd_f, columnStart_f, columnEnd_f);
        copyMatrix_Part2Complete(n / 2, matrix_g, n, matrixB, rowStart_g, rowEnd_g, columnStart_g, columnEnd_g);
        copyMatrix_Part2Complete(n / 2, matrix_h, n, matrixB, rowStart_h, rowEnd_h, columnStart_h, columnEnd_h);

        //conquer
        /*
        //conquer for divide and conquer algorithm
        int matrix_r[n / 2][n / 2];
        int matrix_s[n / 2][n / 2];
        int matrix_t[n / 2][n / 2];
        int matrix_u[n / 2][n / 2];

        int matrix_ae[n / 2][n / 2];
        int matrix_bg[n / 2][n / 2];
        int matrix_af[n / 2][n / 2];
        int matrix_bh[n / 2][n / 2];
        int matrix_ce[n / 2][n / 2];
        int matrix_dg[n / 2][n / 2];
        int matrix_cf[n / 2][n / 2];
        int matrix_dh[n / 2][n / 2];

        multiplySquareMatrices(n / 2, matrix_ae, matrix_a, matrix_e);
        multiplySquareMatrices(n / 2, matrix_bg, matrix_b, matrix_g);
        multiplySquareMatrices(n / 2, matrix_af, matrix_a, matrix_f);
        multiplySquareMatrices(n / 2, matrix_bh, matrix_b, matrix_h);
        multiplySquareMatrices(n / 2, matrix_ce, matrix_c, matrix_e);
        multiplySquareMatrices(n / 2, matrix_dg, matrix_d, matrix_g);
        multiplySquareMatrices(n / 2, matrix_cf, matrix_c, matrix_f);
        multiplySquareMatrices(n / 2, matrix_dh, matrix_d, matrix_h);

        addSquareMatrices(n / 2, matrix_r, matrix_ae, matrix_bg);
        addSquareMatrices(n / 2, matrix_s, matrix_af, matrix_bh);
        addSquareMatrices(n / 2, matrix_t, matrix_ce, matrix_dg);
        addSquareMatrices(n / 2, matrix_u, matrix_cf, matrix_dh);
        */

        //conquer for Strassen's algorithm
        int matrix_fSubH[n / 2][n / 2];
        int matrix_aAddB[n / 2][n / 2];
        int matrix_cAddD[n / 2][n / 2];
        int matrix_gSubE[n / 2][n / 2];
        int matrix_aAddD[n / 2][n / 2];
        int matrix_eAddH[n / 2][n / 2];
        int matrix_bSubD[n / 2][n / 2];
        int matrix_gAddH[n / 2][n / 2];
        int matrix_aSubC[n / 2][n / 2];
        int matrix_eAddF[n / 2][n / 2];

        subSquareMatrices(n / 2, matrix_fSubH, matrix_f, matrix_h); //S1 = B12 - B22
        addSquareMatrices(n / 2, matrix_aAddB, matrix_a, matrix_b); //S2 = A11 + A12
        addSquareMatrices(n / 2, matrix_cAddD, matrix_c, matrix_d); //S3 = A21 + A22
        subSquareMatrices(n / 2, matrix_gSubE, matrix_g, matrix_e); //S4 = B21 - B11
        addSquareMatrices(n / 2, matrix_aAddD, matrix_a, matrix_d); //S5 = A11 + A22
        addSquareMatrices(n / 2, matrix_eAddH, matrix_e, matrix_h); //S6 = B11 + B22
        subSquareMatrices(n / 2, matrix_bSubD, matrix_b, matrix_d); //S7 = A12 - A22
        addSquareMatrices(n / 2, matrix_gAddH, matrix_g, matrix_h); //S8 = B21 + B22
        subSquareMatrices(n / 2, matrix_aSubC, matrix_a, matrix_c); //S9 = A11 - A21
        addSquareMatrices(n / 2, matrix_eAddF, matrix_e, matrix_f); //S10= B11 + B12

        int matrix_p1[n / 2][n / 2];
        int matrix_p2[n / 2][n / 2];
        int matrix_p3[n / 2][n / 2];
        int matrix_p4[n / 2][n / 2];
        int matrix_p5[n / 2][n / 2];
        int matrix_p6[n / 2][n / 2];
        int matrix_p7[n / 2][n / 2];

        multiplySquareMatrices(n / 2, matrix_p1, matrix_a,     matrix_fSubH); //P1 = A11*S1
        multiplySquareMatrices(n / 2, matrix_p2, matrix_aAddB, matrix_h);     //P2 = S2*B22
        multiplySquareMatrices(n / 2, matrix_p3, matrix_cAddD, matrix_e);     //P3 = S3*B11
        multiplySquareMatrices(n / 2, matrix_p4, matrix_d,     matrix_gSubE); //P4 = A22*S4
        multiplySquareMatrices(n / 2, matrix_p5, matrix_aAddD, matrix_eAddH); //P5 = S5*S6
        multiplySquareMatrices(n / 2, matrix_p6, matrix_bSubD, matrix_gAddH); //P6 = S7*S8
        multiplySquareMatrices(n / 2, matrix_p7, matrix_aSubC, matrix_eAddF); //P7 = S9*S10

        int matrix_r[n / 2][n / 2];
        int matrix_s[n / 2][n / 2];
        int matrix_t[n / 2][n / 2];
        int matrix_u[n / 2][n / 2];

        int matrix_rTemp1[n / 2][n / 2];
        int matrix_rTemp2[n / 2][n / 2];
        int matrix_uTemp1[n / 2][n / 2];
        int matrix_uTemp2[n / 2][n / 2];

        addSquareMatrices(n / 2, matrix_rTemp1,        matrix_p5,            matrix_p4); //P5 + P4
        subSquareMatrices(n / 2, matrix_rTemp2,        matrix_rTemp1,        matrix_p2); //(P5 + P4) - P2
        addSquareMatrices(n / 2, matrix_r,             matrix_rTemp2,        matrix_p6); //C11 = (P5 + P4 - P2) + P6

        addSquareMatrices(n / 2, matrix_s,             matrix_p1,            matrix_p2); //C12 = P1 + P2

        addSquareMatrices(n / 2, matrix_t,             matrix_p3,            matrix_p4); //C21 = P3 + P4

        addSquareMatrices(n / 2, matrix_uTemp1,        matrix_p5,            matrix_p1); //P5 + P1
        subSquareMatrices(n / 2, matrix_uTemp2,        matrix_uTemp1,        matrix_p3); //(P5 + P1) - P3
        subSquareMatrices(n / 2, matrix_u,             matrix_uTemp2,        matrix_p7); //C22 = (P5 + P1 - P3) - P7

        //combine
        int rowStart_r = 0;     int rowEnd_r = (n/2) - 1;     int columnStart_r = 0;        int columnEnd_r = (n/2) - 1;
        int rowStart_s = 0;     int rowEnd_s = (n/2) - 1;     int columnStart_s = n/2;      int columnEnd_s = n - 1;
        int rowStart_t = n/2;   int rowEnd_t = n - 1;         int columnStart_t = 0;        int columnEnd_t = (n/2) - 1;
        int rowStart_u = n/2;   int rowEnd_u = n - 1;         int columnStart_u = n/2;      int columnEnd_u = n - 1;

        copyMatrix_Complete2Part(n, matrixAB, rowStart_r, rowEnd_r, columnStart_r, columnEnd_r, (n / 2), matrix_r);
        copyMatrix_Complete2Part(n, matrixAB, rowStart_s, rowEnd_s, columnStart_s, columnEnd_s, (n / 2), matrix_s);
        copyMatrix_Complete2Part(n, matrixAB, rowStart_t, rowEnd_t, columnStart_t, columnEnd_t, (n / 2), matrix_t);
        copyMatrix_Complete2Part(n, matrixAB, rowStart_u, rowEnd_u, columnStart_u, columnEnd_u, (n / 2), matrix_u);

    }
}

void copyMatrix_Part2Complete(int destN, int destMatrix[][destN], int srcN, 
int srcMatrix[][srcN], int rowS, int rowE, int colS, int colE)
{
    int i; int j;
    int k; int l;

    for(i = rowS, k = 0; k < destN; i++, k++)
    {
        for(j = colS, l = 0; l < destN; j++, l++)
        {
            destMatrix[k][l] = srcMatrix[i][j];
        }
    }
}

void addSquareMatrices(int n, int addedMatrix[][n], int matrixA[][n], int matrixB[][n])
{
    int i; int j;
    for(i = 0; i < n; i++)
    {
        for(j = 0; j < n; j++)
        {
            addedMatrix[i][j] = matrixA[i][j] + matrixB[i][j];
        }
    }
}

void subSquareMatrices(int n, int subtractedMatrix[][n], int matrixA[][n], int matrixB[][n])
{
    int i; int j;
    for(i = 0; j < n; i++)
    {
        for(j = 0; j < n; j++)
        {
            subtractedMatrix[i][j] = matrixA[i][j] - matrixB[i][j];
        }
    }
}

void copyMatrix_Complete2Part(int destN, int destMatrix[][destN], int rowStart, int rowEnd, int columnStart, int columnEnd, int srcN, int srcMatrix[][srcN])
{
    int i; int j;
    int k; int l;

    for(i = 0, k = rowStart; k <= rowEnd; i++, k++)
    {
        for(j = 0, l = columnStart; l <= columnEnd; j++, l++)
        {
            destMatrix[k][l] = srcMatrix[i][j];
        }
    }
}

void showSquareMatrix(int n, int matrix[][n])
{
    int i;
    int j;
    for(i = 0; i < n; i++)
    {
        printf("\n");
        for(j = 0; j < n; j++)
        {
            printf("\t%d", matrix[i][j]);
        }
    }
}

在recurseve功能的征服部分&#34; multiplySquareMatrices&#34;,可能会有一些问题。

如果我替换了Strassen&#34;算法&#34;征服?部分代码用&#34;征服Divide&amp;征服算法&#34;通过消除“征服分裂&amp;征服算法&#34;部分&amp;评论&#34;征服斯特拉森的算法&#34;部分,代码工作。我不明白为什么?

1 个答案:

答案 0 :(得分:3)

征服分裂&amp;征服算法算法不使用 subSquareMatrices函数。现在看看这个函数的定义

int i; int j;
for(i = 0; j < n; i++)  // <-------
{
    for(j = 0; j < n; j++)
    {
        subtractedMatrix[i][j] = matrixA[i][j] - matrixB[i][j];
    }
}

j应由外部for循环中的i替换。