如何使用这个C代码使用Strassen算法乘以两个矩阵?

时间:2012-03-02 19:19:04

标签: c matrix matrix-multiplication strassen

我在C中寻找Strassen's Algorithm的实现,最后我发现了这段代码。

使用multiply功能:

void multiply(int n, matrix a, matrix b, matrix c, matrix d);

将两个矩阵ab相乘并将结果放入cd是中间矩阵)。矩阵ab应具有以下类型:

typedef union _matrix 
{
    double **d;
    union _matrix **p;
} *matrix;

我动态分配了四个矩阵abcd(二维二维数组)并将其地址分配给字段{{ 1}}:

_matrix.d

此代码已成功编译,但与#include "strassen.h" #define SIZE 50 int main(int argc, char *argv[]) { double ** matA, ** matB, ** matC, ** matD; union _matrix ma, mb, mc, md; int i = 0, j = 0, n; matA = (double **) malloc(sizeof(double *) * SIZE); for (i = 0; i < SIZE; i++) matA[i] = (double *) malloc(sizeof(double) * SIZE); // Do the same for matB, matC, matD. ma.d = matA; mb.d = matB; mc.d = matC; md.d = matD; // Initialize matC and matD to 0. // Read n. // Read matA and matB. multiply(n, &ma, &mb, &mc, &md); return 0; } &gt;崩溃n

strassen.c:

BREAK

strassen.h:

#include "strassen.h"

/* c = a * b */
void multiply(int n, matrix a, matrix b, matrix c, matrix d)
{
    if (n <= BREAK) {
      double sum, **p = a->d, **q = b->d, **r = c->d;
      int i, j, k;

      for (i = 0; i < n; i++)
         for (j = 0; j < n; j++) {
            for (sum = 0., k = 0; k < n; k++)
               sum += p[i][k] * q[k][j];
            r[i][j] = sum;
         }
    } else {
        n /= 2;
        sub(n, a12, a22, d11);
        add(n, b21, b22, d12);
        multiply(n, d11, d12, c11, d21);
        sub(n, a21, a11, d11);
        add(n, b11, b12, d12);
        multiply(n, d11, d12, c22, d21);
        add(n, a11, a12, d11);
        multiply(n, d11, b22, c12, d12);
        sub(n, c11, c12, c11);
        sub(n, b21, b11, d11);
        multiply(n, a22, d11, c21, d12);
        add(n, c21, c11, c11);
        sub(n, b12, b22, d11);
        multiply(n, a11, d11, d12, d21);
        add(n, d12, c12, c12);
        add(n, d12, c22, c22);
        add(n, a21, a22, d11);
        multiply(n, d11, b11, d12, d21);
        add(n, d12, c21, c21);
        sub(n, c22, d12, c22);
        add(n, a11, a22, d11);
        add(n, b11, b22, d12);
        multiply(n, d11, d12, d21, d22);
        add(n, d21, c11, c11);
        add(n, d21, c22, c22);
    }
}

/* c = a + b */
void add(int n, matrix a, matrix b, matrix c)
{
    if (n <= BREAK) {
        double **p = a->d, **q = b->d, **r = c->d;
        int i, j;

        for (i = 0; i < n; i++)
           for (j = 0; j < n; j++)
              r[i][j] = p[i][j] + q[i][j];
    } else {
        n /= 2;
        add(n, a11, b11, c11);
        add(n, a12, b12, c12);
        add(n, a21, b21, c21);
        add(n, a22, b22, c22);
    }
}

/* c = a - b */
void sub(int n, matrix a, matrix b, matrix c)
{
    if (n <= BREAK) {
        double **p = a->d, **q = b->d, **r = c->d;
        int i, j;

        for (i = 0; i < n; i++)
           for (j = 0; j < n; j++)
              r[i][j] = p[i][j] - q[i][j];
    } else {
        n /= 2;
        sub(n, a11, b11, c11);
        sub(n, a12, b12, c12);
        sub(n, a21, b21, c21);
        sub(n, a22, b22, c22);
    }
}

我的问题是如何使用函数#define BREAK 8 typedef union _matrix { double **d; union _matrix **p; } *matrix; /* Notational shorthand to access submatrices for matrices named a, b, c, d */ #define a11 a->p[0] #define a12 a->p[1] #define a21 a->p[2] #define a22 a->p[3] #define b11 b->p[0] #define b12 b->p[1] #define b21 b->p[2] #define b22 b->p[3] #define c11 c->p[0] #define c12 c->p[1] #define c21 c->p[2] #define c22 c->p[3] #define d11 d->p[0] #define d12 d->p[1] #define d21 d->p[2] #define d22 d->p[3] (如何实现矩阵)。

strassen.h

strassen.c

4 个答案:

答案 0 :(得分:3)

Atom said一样,您需要为两个矩阵正确初始化matrix.p

1)首先,您的matrixunion,因此p基本上被d解释为_matrix **,这在这里没有意义 - 这就是崩溃的原因。您可能需要将matrix改为struct 最后,根据定义,p是一个子矩阵数组,因此它应该是struct _matrix *(并且在需要时您需要malloc实际数组)或struct _matrix[4](其中是不可能的:))。

typedef struct _matrix 
{
    double **d;
    struct _matrix *p;
} *matrix;

2)现在,让我们看看p应该是什么。

                           │
A.d ->  d1 -> a[1,1] a[1,2]│a[1,3] a[1,4]
        d2 -> a[2,1] a[2,2]│a[2,3] a[2,4]
             ─────────────────────────────
        d3 -> a[3,1] a[3,2]│a[3,3] a[3,4]
        d4 -> a[4,1] a[4,2]│a[4,3] a[4,4]
                           │

p指向matrix结构数组!特殊性是使这些结构的d指向 A 内部 以这种方式(p[k].d)[i][j]是相应的子矩阵的元素:

p[0].d -> p01 -> a[1,1]    p[1].d -> p11 -> a[1,3]
          p02 -> a[2,1]              p12 -> a[2,3]

p[2].d -> p21 -> a[3,1]    p[3].d -> p31 -> a[3,3]
          p22 -> a[4,1]              p32 -> a[4,3]

现在可以推断出算法是为任意偶数大小的 A 初始化p吗?

什么时候开始初始化它? ;)

答案 1 :(得分:2)

当n> BREAK,矩阵乘法算法使用分层矩阵表示(p的字段union _matrix,而不是字段d)。

在分配内存和初始化矩阵ab时,您需要调整分层表示的代码。

答案 2 :(得分:2)

好吧,我不确定你发布的代码有什么问题,但是我想指出你刚才在英特尔编程竞赛中使用的Strassen算法的实现:{{3} }。也许你可以在那里找到一些有用的提示。

答案 3 :(得分:0)

对于实施strassen算法以及@Tudor发布的document,请查看这个the code