我是MPI的新手。我需要在2D拓扑(网格)中制作矩阵乘法程序。第一矩阵(A)沿x坐标分布,第二矩阵(B)沿y坐标分布。每个过程都会计算一个子矩阵。我使用MPI_Bcast
在维度中发送子矩阵,但在该程序之后不会继续。我做错了什么?
这是代码。
#include<stdio.h>
#include<stdlib.h>
#include<mpi/mpi.h>
#define NUM_DIMS 2
#define N 81
#define A(i, j) A[N*(i)+(j)]
#define B(i, j) B[N*(i)+(j)]
#define C(i, j) C[N*(i)+(j)]
#define AA(i, j) AA[k *(i)+(j)] //
#define BB(i, j) BB[k*(i)+(j)]
#define CC(i, j) CC[k*(i)+(j)]
int main(int argc, char **argv) {
MPI_Init(&argc, &argv);
int threadCount;
int threadRank;
MPI_Comm_size(MPI_COMM_WORLD, &threadCount);
int dims[NUM_DIMS] = {0};
//Создаем решетку
int periods[2] = {0, 0};
MPI_Comm comm_2D;
MPI_Comm comm_1D[2];
MPI_Dims_create(threadCount, NUM_DIMS, dims);
MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &comm_2D);
MPI_Comm_rank(comm_2D, &threadRank);
int k = N/dims[1];
double *A = (double*)calloc(N*N, sizeof(double));
double *B = (double*)calloc(N*N, sizeof(double));
double *C = (double*)calloc(N*N, sizeof(double));
double startTime = MPI_Wtime();
int subdims[2];
subdims[0] = 0;
subdims[1] = 1;
MPI_Cart_sub(comm_2D, subdims, &comm_1D[0]);
subdims[0] = 1;
subdims[1] = 0;
MPI_Cart_sub(comm_2D, subdims, &comm_1D[1]);
MPI_Datatype column, matrix;
MPI_Type_vector(N, N / k, N, MPI_DOUBLE, &column);
MPI_Type_create_resized(column, 0, N / k * sizeof(double), &column);
MPI_Type_commit(&column);
double *AA, *BB, *CC;
AA = (double*)calloc(N * k, sizeof(double));
BB = (double*)calloc(N * k, sizeof(double));
CC = (double*)calloc(k * k , sizeof(double));
int threadCoords[2];
MPI_Comm_rank(comm_2D, &threadRank);
MPI_Cart_coords(comm_2D, threadRank, NUM_DIMS, threadCoords);
if (threadCoords[0] == 0) {
for (int i = 0; i < N; ++i) {
for (int j = 0; j < N; ++j) {
A(i, j) = 1;
B(i, j) = 1;
}
}
}
if (threadCoords[1] == 0) {
MPI_Scatter(A, N * k, MPI_DOUBLE, AA, N * k, MPI_DOUBLE, 0, comm_1D[0]);
}
if (threadCoords[0] == 0) {
int offset[3] = {0, 1, 2};
int send[3] = {1, 1, 1};
MPI_Scatterv(B, send, offset, column, BB, N * k , MPI_DOUBLE, 0, comm_1D[1]);
}
int r = MPI_Bcast(AA, k*N, MPI_DOUBLE, 0, comm_1D[1]);
fprintf(stderr, "r = %d\n", r);
int p = MPI_Bcast(BB, k*N, MPI_DOUBLE, 0, comm_1D[0]);
fprintf(stderr, "p = %d\n", p);
/*...*/
}