Strassen乘法 - c程序

时间:2016-01-21 14:13:11

标签: c algorithm matrix-multiplication double-pointer

我们必须在实际会话中实现strassen的乘法,我写的代码如下所示,你可以看到我使用了很多中间矩阵。我想知道如何返回一个二维数组一个函数,使代码看起来更干净,更容易理解,它也会给我一些关于指针的见解(对我来说是一个弱区),即说是使用双指针作为子函数的返回类型(int ** sub(args list) )) 因为我的strassen函数有原型strassen(int,int [],int [] ..) 当strassen函数的一个参数是子函数的结果时,我得到一个错误,说明期望int(*)[]但返回int ** 为了解决这个问题,我使用int(*)[]对子函数的结果进行了类型化处理,但它没有按预期工作 请帮忙 ?谢谢!

#include<stdio.h>
#include<stdlib.h>

void add(int n,int a[n][n],int b[n][n],int result[][n])
{
printf("---add---\n");
int i,j;
//int **result = (int **)malloc(n*sizeof(int *));

/*for(i=0;i<n;i++)
    result[i] = (int *)malloc(n*sizeof(int));*/

for(i=0;i<n;i++)
{
    for(j=0;j<n;j++)
    {
        result[i][j] = a[i][j] + b[i][j];
        printf("%d\t",result[i][j]);
    }
    printf("\n");
}
//return result;
}

   void sub(int n,int a[n][n],int b[n][n],int result[][n])
  {
printf("---sub---\n");
int i,j;
/*int **result = (int **)malloc(n*sizeof(int *));

for(i=0;i<n;i++)
    result[i] = (int *)malloc(n*sizeof(int));*/

for(i=0;i<n;i++)
{
    for(j=0;j<n;j++)
    {
        result[i][j] = a[i][j] - b[i][j];
        printf("%d\t",result[i][j]);
    }
}

}

 void divide(int n,int a[n][n],int c[n/2][n/2],int i,int j)
 {
int i1,i2,j1,j2;
for(i1=0,i2=i;i1<n/2;i1++,i2++)
{
    for(j1=0,j2=j;j1<n/2;j1++,j2++)
    {
        c[i1][j1] = a[i2][j2];
    }
}
}

  void join(int n,int a[][n],int c[][n/2],int i,int j)      
 {
printf("join\n");
int i1,i2,j1,j2;
for(i1=0,i2=i;i1<(n/2);i1++,i2++)
{
    for(j1=0,j2=j;j1<(n/2);j1++,j2++)
    {
        a[i2][j2] = c[i1][j1];
        printf("c[%d][%d] %d\n",i1,j1,c[i1][j1]);
    }
}
}

 void multiply(int n,int a[][n],int b[][n],int result[][n])
{
int i,j;

if(n==2)
{
    //partial products
    printf("base case\n");
    int p1 = (a[0][0]+a[1][1])*(b[0][0]+b[1][1]);
    int p2 = (a[1][0]+a[1][1])*b[0][0];
    int p3 = a[0][0]*(b[0][1]-b[1][1]);
    int p4 = a[1][1]*(b[1][0]-b[0][0]);
    int p5 = (a[0][0]+a[0][1])*b[1][1];
    int p6 = (a[1][0]-a[0][0])*(b[0][0]+b[0][1]);
    int p7 = (a[0][1]-a[1][1])*(b[1][0]+b[1][1]);

    int c11 = p1 + p4 - p5 + p7;
    int c12 = p3 + p5;
    int c21 = p2 + p4;
    int c22 = p1 + p3 - p2 + p6;

    result[0][0] = c11;
    result[0][1] = c12;
    result[1][0] = c21;
    result[1][1] = c22;

    for(i=0;i<2;i++)
    {
        for(j=0;j<2;j++)
        {
            printf("%d\t",result[i][j]);
        }
        printf("\n");
    }

}

else
{
    int a11[n/2][n/2];
    int a12[n/2][n/2];
    int a21[n/2][n/2];
    int a22[n/2][n/2];

    int b11[n/2][n/2];
    int b12[n/2][n/2];
    int b21[n/2][n/2];
    int b22[n/2][n/2];

    //divide matrices A & B into four parts
    divide(n,a,a11,0,0);
    divide(n,a,a12,0,n/2);
    divide(n,a,a21,n/2,0);
    divide(n,a,a22,n/2,n/2);

    divide(n,b,b11,0,0);
    divide(n,b,b12,0,n/2);
    divide(n,b,b21,n/2,0);
    divide(n,b,b22,n/2,n/2);

    //partial products

    int p1[n/2][n/2],p2[n/2][n/2],p3[n/2][n/2],p4[n/2][n/2],p5[n/2][n/2],p6[n/2][n/2],p7[n/2][n/2];

    int c11[n/2][n/2],c12[n/2][n/2],c21[n/2][n/2],c22[n/2][n/2];

    int i1[n/2][n/2],i2[n/2][n/2];

    add(n/2,a11,a22,i1);
    add(n/2,b11,b22,i2);
    multiply(n/2,i1,i2,p1);

    int i3[n/2][n/2];
    add(n/2,a21,a22,i3);
    multiply(n/2,i3,b11,p2);

    int i4[n/2][n/2];
    sub(n/2,b12,b22,i4);
    multiply(n/2,a11,i4,p3);

    int i5[n/2][n/2];
    sub(n/2,b21,b11,i5);
    multiply(n/2,a22,i5,p4);

    int i6[n/2][n/2];
    add(n/2,a11,a12,i6);
    multiply(n/2,i6,b22,p5);

    int i7[n/2][n/2];
    int i8[n/2][n/2];
    sub(n/2,a21,a11,i7);
    add(n/2,b11,b12,i8);
    multiply(n/2,i7,i8,p6);

    int i9[n/2][n/2];
    int i10[n/2][n/2];

    sub(n/2,a12,a22,i9);
    add(n/2,b21,b22,i10);
    multiply(n/2,i9,i10,p7);

    //for c11
    int r1[n/2][n/2];
    int r2[n/2][n/2];

    add(n/2,p1,p4,r1);        //sub operation
    sub(n/2,r1,p5,r2);        //sub operation
    add(n/2,r2,p7,c11);        //main operation

    //for c12
    add(n/2,p3,p5,c12);

    //for c21
    add(n/2,p2,p4,c21);

    //for c22
    int r3[n/2][n/2];
    int r4[n/2][n/2];

    add(n/2,p1,p3,r3);       //sub operation
    sub(n/2,r3,p2,r4);       //sub operation
    add(n/2,r4,p6,c22);       //main operation


    join(n,result,c11,0,0);
    join(n,result,c12,0,n/2);
    join(n,result,c21,n/2,0);
    join(n,result,c22,n/2,n/2);

    printf("---c11---\n");
    for(i=0;i<n/2;i++)
    {
        for(j=0;j<n/2;j++)
        {
            printf("%d\t",c11[i][j]);
        }
        printf("\n");
    }
    printf("---c12---\n");
    for(i=0;i<n/2;i++)
    {
        for(j=0;j<n/2;j++)
        {
            printf("%d\t",c12[i][j]);
        }
        printf("\n");
    }
    printf("---c21---\n");
    for(i=0;i<n/2;i++)
    {
        for(j=0;j<n/2;j++)
        {
            printf("%d\t",c21[i][j]);
        }
        printf("\n");
    }
     printf("---c22---\n");
    for(i=0;i<n/2;i++)
    {
        for(j=0;j<n/2;j++)
        {
            printf("%d\t",c22[i][j]);
        }
        printf("\n");
    }
    /*for(i=0;i<n;i++)
    {
        for(j=0;j<n;j++)
        {
            printf("%d\t",result[i][j]);
        }
        printf("\n");
    }*/

}

}

 int main()
 {
int n;
printf("Enter the order of the matrices(power of 2)\n");
scanf("%d",&n);

int i,j;
int a[n][n],b[n][n];

printf("Enter first matrix\n");
for(i=0;i<n;i++)
{
    for(j=0;j<n;j++)
    {
        scanf("%d",&a[i][j]);
    }
}
printf("Enter second matrix\n");
for(i=0;i<n;i++)
{
    for(j=0;j<n;j++)
    {
        scanf("%d",&b[i][j]);
    }
}
printf("First matrix is \n");
for(i=0;i<n;i++)
{
    for(j=0;j<n;j++)
    {
        printf("%d\t",a[i][j]);
    }
    printf("\n");
}

printf("Second matrix is \n");
for(i=0;i<n;i++)
{
    for(j=0;j<n;j++)
    {
        printf("%d\t",b[i][j]);
    }
    printf("\n");
}

int r[n][n];
multiply(n,a,b,r);
printf("---RESULT OF MULTIPLICATION---\n");
for(i=0;i<n;i++)
{
    for(j=0;j<n;j++)
    {
        printf("%d\t",r[i][j]);
    }
    printf("\n");
}

return 0;

}

0 个答案:

没有答案