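This set of nested for loops works for M = 64 and N = 64, but it doesn't work when I set M = 128 and N = 64. I have a separate program that checks the matrix multiplication for correct values. Intuitively it seems like it should still work, but it gives me the wrong answer.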
for(int m=64;m<=M;m+=64){
    for(int n=64;n<=N;n+=64){
        for(int i = m-64; i < m; i+=16){
            float *A_column_start, *C_column_start;
            __m128 c_1, c_2, c_3, c_4, a_1, a_2, a_3, a_4, mul_1,
                   mul_2, mul_3, mul_4, b_1;
            int j, k;
            for(j = m-64; j < m; j++){
                //Load 16 contiguous column-aligned elements from matrix C into
                //the c_1-c_4 registers
                C_column_start = C+i+j*M;
                c_1 = _mm_loadu_ps(C_column_start);
                c_2 = _mm_loadu_ps(C_column_start+4);
                c_3 = _mm_loadu_ps(C_column_start+8);
                c_4 = _mm_loadu_ps(C_column_start+12);
                for (k=n-64; k < n; k+=2){
                    //Load 16 contiguous column-aligned elements from matrix A into
                    //the a_1-a_4 registers
                    A_column_start = A+k*M;
                    a_1 = _mm_loadu_ps(A_column_start+i);
                    a_2 = _mm_loadu_ps(A_column_start+i+4);
                    a_3 = _mm_loadu_ps(A_column_start+i+8);
                    a_4 = _mm_loadu_ps(A_column_start+i+12);
                    //Load a value into register b_1 to act as a "B" (or "A^T")
                    //element to multiply against the A matrix
                    b_1 = _mm_load1_ps(A_column_start+j);
                    mul_1 = _mm_mul_ps(a_1, b_1);
                    mul_2 = _mm_mul_ps(a_2, b_1);
                    mul_3 = _mm_mul_ps(a_3, b_1);
                    mul_4 = _mm_mul_ps(a_4, b_1);
                    //Accumulate the products of the A and "B" (or "A^T")
                    //elements into the C registers
                    c_4 = _mm_add_ps(c_4, mul_4);
                    c_3 = _mm_add_ps(c_3, mul_3);
                    c_2 = _mm_add_ps(c_2, mul_2);
                    c_1 = _mm_add_ps(c_1, mul_1);
                    //Move over one column in A, and load the next 16 contiguous
                    //column-aligned elements from matrix A into the a_1-a_4 registers
                    A_column_start+=M;
                    a_1 = _mm_loadu_ps(A_column_start+i);
                    a_2 = _mm_loadu_ps(A_column_start+i+4);
                    a_3 = _mm_loadu_ps(A_column_start+i+8);
                    a_4 = _mm_loadu_ps(A_column_start+i+12);
                    //Load a value into register b_1 to act as a "B" (or "A^T")
                    //element to multiply against the A matrix
                    b_1 = _mm_load1_ps(A_column_start+j);
                    mul_1 = _mm_mul_ps(a_1, b_1);
                    mul_2 = _mm_mul_ps(a_2, b_1);
                    mul_3 = _mm_mul_ps(a_3, b_1);
                    mul_4 = _mm_mul_ps(a_4, b_1);
                    //Accumulate the products of the A and "B" (or "A^T")
                    //elements into the C registers
                    c_4 = _mm_add_ps(c_4, mul_4);
                    c_3 = _mm_add_ps(c_3, mul_3);
                    c_2 = _mm_add_ps(c_2, mul_2);
                    c_1 = _mm_add_ps(c_1, mul_1);
                }
                //Store the accumulated C values back to memory
                _mm_storeu_ps(C_column_start, c_1);
                _mm_storeu_ps(C_column_start+4, c_2);
                _mm_storeu_ps(C_column_start+8, c_3);
                _mm_storeu_ps(C_column_start+12, c_4);
            }
        }
    }
}
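The question mentions a separate program that checks the result. A minimal sketch of such a check is below, assuming the kernel is meant to compute C += A·Aᵀ with A (M×N) and C (M×M) both stored column-major with leading dimension M; that is inferred from the indexing `A+k*M` and `C+i+j*M`, not stated in the question. `ref_aat` and `count_touched` are hypothetical helper names. `count_touched` replays only the m/i/j index arithmetic of the loop nest above, with no SSE, and marks every C entry the `_mm_storeu_ps` calls would write.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Scalar reference for C += A * A^T (assumed semantics).
   A: M x N, column-major, leading dimension M.
   C: M x M, column-major, leading dimension M. */
void ref_aat(const float *A, float *C, int M, int N)
{
    for (int j = 0; j < M; j++) {
        for (int i = 0; i < M; i++) {
            float sum = 0.0f;
            for (int k = 0; k < N; k++)
                sum += A[i + k * M] * A[j + k * M];
            C[i + j * M] += sum;
        }
    }
}

/* Replays only the m/i/j index arithmetic of the SSE loop nest and marks
   every C entry that the _mm_storeu_ps calls would write. The n loop is
   omitted: it changes which k values are summed, not which C entries are
   stored. touched must hold at least M*M bytes. Returns the marked count. */
int count_touched(int M, unsigned char *touched)
{
    assert(M % 64 == 0);  /* the loop nest only handles whole 64-blocks */
    memset(touched, 0, (size_t)M * M);
    for (int m = 64; m <= M; m += 64)
        for (int i = m - 64; i < m; i += 16)
            for (int j = m - 64; j < m; j++)
                for (int r = 0; r < 16; r++)  /* c_1..c_4 cover 16 rows */
                    touched[i + r + j * M] = 1;

    int count = 0;
    for (size_t e = 0; e < (size_t)M * M; e++)
        count += touched[e];
    return count;
}
```

For M = 64, `count_touched` marks all 4096 entries of C. For M = 128 it marks 8192 of the 16384 entries: `i` and `j` are both drawn from the same window `[m-64, m)`, so only the two diagonal 64×64 blocks of C are ever stored, and entries such as row 0, column 64 are never written. Comparing the SSE result against `ref_aat` entry by entry is one way to see where the mismatches fall.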
Answer 0 (score: 2)
My guess is that in the line
C_column_start = C+i+j*M;
you need to use m instead of M, and possibly also in the other lines where you use M.
However, I don't really understand your code, because you didn't explain what it is supposed to do, and I'm not a math programmer.
Answer 1 (score: 0)
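It works for M = 64 and N = 64 because in those cases you only do a single iteration of each of the corresponding loops (the two outermost ones). With M = 128 you now take two steps in the outer loop, and in that case the line
C_column_start = C+i+j*M;
and the line
A_column_start = A+k*M;
produce the same results for the inner loops, so over the two steps taken in the outer loop (m = 64, 128) you simply double the result of the single m = 128 step. The fix is as simple as changing M to m, so that you use the iteration variable.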
Also, you should consider aligning the data in A and C so that you can use SSE aligned loads. This can result in faster code.
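A minimal sketch of the aligned-load suggestion, assuming GCC or Clang on x86 with SSE (on those compilers `<xmmintrin.h>` also declares `_mm_malloc` and `_mm_free`; other toolchains may need a different header). It assumes M is a multiple of 4, so every column of a 16-byte-aligned column-major matrix also starts on a 16-byte boundary, and that `i` stays a multiple of 4 (it is a multiple of 16 in the question). `alloc_matrix16` and `aat_step_aligned` are hypothetical names; the second function is one k-step of the question's inner loop rewritten with `_mm_load_ps`.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <xmmintrin.h>  /* SSE intrinsics; GCC/Clang also provide _mm_malloc here */

/* Allocate a rows x cols float matrix on a 16-byte boundary, as
   _mm_load_ps/_mm_store_ps require. Release it with _mm_free. */
float *alloc_matrix16(int rows, int cols)
{
    return (float *)_mm_malloc((size_t)rows * cols * sizeof(float), 16);
}

/* One k-step of the question's inner loop using aligned loads:
   c[0..3] += A[i..i+15, k] * A[j, k], with A column-major (leading dim M).
   Every address passed to _mm_load_ps must be 16-byte aligned, which holds
   when A itself is aligned and both M and i are multiples of 4. */
void aat_step_aligned(const float *A, int M, int i, int j, int k, __m128 c[4])
{
    assert(((uintptr_t)A & 15) == 0 && M % 4 == 0 && i % 4 == 0);
    const float *col = A + (size_t)k * M;   /* start of column k */
    __m128 b = _mm_load1_ps(col + j);       /* broadcast A[j, k]; no alignment needed */
    for (int q = 0; q < 4; q++)
        c[q] = _mm_add_ps(c[q], _mm_mul_ps(_mm_load_ps(col + i + 4 * q), b));
}
```

Allocating A with `alloc_matrix16(M, N)` and C with `alloc_matrix16(M, M)` keeps `C+i+j*M` aligned under the same assumptions, so the C loads and stores could likewise use `_mm_load_ps` and `_mm_store_ps`.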