使用SSE实现高斯消除

时间:2015-11-02 00:19:25

标签: c x86 linear-algebra sse simd

我尝试使用SSE实现高斯消除,但我认为我的对齐可能已关闭,或者我可能以错误的顺序执行某些块。

作为一个例子,这里是输入矩阵的第一个8x8子矩阵,使用串行实现计算的矩阵,以及不正确的输出矩阵:

输入矩阵:

50.000000 15.000000 44.000000 18.000000 15.000000 21.000000 32.000000 6.000000 
35.000000 39.000000 26.000000 44.000000 8.000000 7.000000 24.000000 11.000000 
36.000000 21.000000 45.000000 15.000000 17.000000 31.000000 48.000000 9.000000 
33.000000 15.000000 13.000000 41.000000 29.000000 41.000000 22.000000 30.000000 
46.000000 19.000000 35.000000 37.000000 32.000000 17.000000 29.000000 43.000000 
42.000000 11.000000 23.000000 31.000000 31.000000 6.000000 42.000000 22.000000 
40.000000 34.000000 21.000000 8.000000 14.000000 7.000000 47.000000 14.000000 
7.000000 27.000000 33.000000 17.000000 4.000000 37.000000 11.000000 43.000000 

参考矩阵:

1.000000 0.300000 0.880000 0.360000 0.300000 0.420000 0.640000 0.120000 
0.000000 1.000000 -0.168421 1.101754 -0.087719 -0.270175 0.056140 0.238596 
0.000000 0.000000 1.000000 -0.611648 0.471791 1.239255 1.621728 0.149377 
0.000000 0.000000 0.000000 1.000000 1.878897 3.329519 1.773631 1.905714 
0.000000 0.000000 0.000000 0.000000 1.000000 22.894316 9.444982 -9.377328 
0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.259213 -0.374202 
0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 -0.914501 
0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 

输出矩阵不正确:

1.000000 0.300000 0.880000 0.360000 0.300000 0.420000 0.640000 0.120000 
0.000000 1.000000 -0.168421 1.101754 -0.087719 -0.270175 0.056140 0.238596 
0.000000 0.000000 1.000000 -0.611648 0.471791 1.239255 1.621728 0.149377 
0.000000 0.000000 0.000000 1.000000 1.878897 3.329519 1.773631 1.905714 
0.000000 0.000000 0.000000 0.000000 -1.520596 -34.812996 -14.361997 14.259123 
0.000000 0.000000 0.000000 0.000000 0.000000 260.456787 139.866196 -114.162613 
0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 -399578.125000 326107.625000 
0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 -35436253184.000000 

以下是功能本身:

void 
gauss_eliminate_using_sse(const Matrix A, Matrix U)
{
    // iterators
    unsigned int i, j, k;
    // shorthand for matrix elements
    float *a_el = A.elements;
    float *u_el = U.elements;

    // load A matrix into U matrix
    for (i = 0; i < MATRIX_SIZE; ++i){
        for(j = 0; j < MATRIX_SIZE/4; ++j){
            // copy four elements at a time
            __m128 buffer = _mm_load_ps(&a_el[MATRIX_SIZE*i + j*4]);
            _mm_store_ps(&u_el[MATRIX_SIZE*i + j*4], buffer); 
        }
    }

    // for each pivot in the matrix
    for (k = 0; k < MATRIX_SIZE; ++k){
        // find the pivot at U[k][k]
        float pivot = u_el[MATRIX_SIZE*k + k];
        // ensure matrix stability
        if (pivot == 0)
          exit(EXIT_FAILURE);

        // load pivot into all four sections of register
        __m128 m_pivot = _mm_set1_ps(pivot);
        __m128 buffer;

        // Division step
        // Beginning with the block containing u[k][k], divide each four-word block by the pivot
        for (j = k/4*4; j < MATRIX_SIZE/4; ++j){
            buffer = _mm_load_ps(&u_el[MATRIX_SIZE*k + j*4]);
            buffer = _mm_div_ps(buffer, m_pivot);
            _mm_store_ps(&u_el[MATRIX_SIZE*k + j*4], buffer);
        }

        // Elimination step
        // Iterating over each row
        for (i = (k+1); i < MATRIX_SIZE; ++i){
            // If in one of the last four blocks, a four-word block cannot be created. Process serially.
            if (i > MATRIX_SIZE - 4) {
                for (j = (k+1); j < MATRIX_SIZE; ++j)
                    u_el[MATRIX_SIZE * i + j] = u_el[MATRIX_SIZE * i + j] - (u_el[MATRIX_SIZE * i + k] * u_el[MATRIX_SIZE * k + j]);
            } else {
                // If u[i][k+1] is not aligned on a four-word block, process serially until reaching an index that is
                int serial_process_count = ((k+1) % 4 == 0) ? 0 : 4 - ((k+1)%4);
                for (j = (k+1); j < (k+1+serial_process_count); ++j)
                    u_el[MATRIX_SIZE * i + j] = u_el[MATRIX_SIZE * i + j] - (u_el[MATRIX_SIZE * i + k] * u_el[MATRIX_SIZE * k + j]);

                // Iterate over each four-word block, beginning at the index reached by lines 158-161
                __m128 m0, m1, m2, m3, m4;
                // fetch U[MATRIX_SIZE * i + k], placing the same word in each index
                m1 = _mm_load1_ps(&u_el[MATRIX_SIZE * i + k]);
                for (j = (k+1+serial_process_count); j < MATRIX_SIZE; j+=4){
                    // fetch U[MATRIX_SIZE * i + j + n], where n = 0..3
                    m0 = _mm_load_ps(&u_el[MATRIX_SIZE * i + j]);
                    // fetch U[MATRIX_SIZE * k + j + n], where n = 0..3
                    m2 = _mm_load_ps(&u_el[MATRIX_SIZE * k + j]);

                    // U[MATRIX_SIZE * i + k] * U[MATRIX_SIZE * k + j]
                    m3 = _mm_mul_ps(m1, m2);
                    // U[MATRIX_SIZE * i - j] - m1
                    m4 = _mm_sub_ps(m0, m3);
                    // U[MATRIX_SIZE * i - j] = m0
                    _mm_store_ps(&u_el[MATRIX_SIZE * i + j], m4);
                }
             }
             u_el[MATRIX_SIZE * i + k] = 0;
        }
    }
}

任何帮助调试都将不胜感激。

1 个答案:

答案 0 :(得分:3)

我通过修复循环边界解决了这个问题。起初,我试图在一个可被4整除的索引处开始每个循环,如下所示:

j = k/4*4

并且假设我可以重新处理元素。我应该做这样的事情:

// Division Step
// Determine the number of elements that must be processed serially
// so that the remainder can be processed in groups of 4
int serial_process_count = (MATRIX_SIZE-(k+1)) % 4;

// for each element found, divide by the pivot serially
if (serial_process_count >= 1)
    u_el[MATRIX_SIZE*k + k + 1] = u_el[MATRIX_SIZE*k + k + 1]/pivot;
if (serial_process_count >= 2)
    u_el[MATRIX_SIZE*k + k + 2] = u_el[MATRIX_SIZE*k + k + 2]/pivot;
if (serial_process_count >= 3)
    u_el[MATRIX_SIZE*k + k + 3] = u_el[MATRIX_SIZE*k + k + 3]/pivot;

// for the remaining elements, divide by the pivot in groups of four
for (j = (k+1+serial_process_count); j < MATRIX_SIZE; j+=4){
    buffer = _mm_load_ps(&u_el[MATRIX_SIZE*k + j]);
    buffer = _mm_div_ps(buffer, m_pivot);
    _mm_store_ps(&u_el[MATRIX_SIZE*k + j], buffer);
}

分区和淘汰步骤。