如何使用avx2内在函数优化chgemm(int = char * char)矩阵乘法?

时间:2016-09-13 12:11:31

标签: c matrix simd avx2

似乎很少有关于chgemm(int = char * char)矩阵乘法的讨论。假设M%8 = 0,N%8 = 0,K%8 = 0,B被转置。我记得支持AVX2的CPU只有16个ymm寄存器。所以我尝试实现2x8的阻塞矩阵以最大化使用寄存器。但是,我找不到更好的解决方案(例如,修改算法将pb的负载移动到外环)。我担心的另一个问题是总和减少的延迟(permute,sli,add)。 我也试过4x8和8x8,看起来8x8严重降低了性能。

有人可以帮我进一步优化此代码吗?谢谢!

void _chgemm_mm_u_c_N_T_2x8(
size_t M, size_t N, size_t K, float scaleAB, 
unsigned char *A, size_t lda, signed char *B, size_t ldb,
float scaleT, int *C, size_t ldc)
{
int h = M;
int w = N;
int d = K;
int i, j, k;
__m256i tmp_short = _mm256_set1_epi16(1);
for (i = 0; i < h; i += 2) {
    __m256i pc0, pc1, pc2, pc3;
    for (j = 0; j < w; j += 8 ) {
        unsigned char *pa0 = A + i * lda;
        unsigned char *pa1 = pa0 + 1*lda;

        signed char *pb0 = (signed char*)B + j*ldb;
        signed char *pb1 = pb0 + 1*ldb;
        signed char *pb2 = pb0 + 2*ldb;
        signed char *pb3 = pb0 + 3*ldb;
        signed char *pb4 = pb0 + 4*ldb;
        signed char *pb5 = pb0 + 5*ldb;
        signed char *pb6 = pb0 + 6*ldb;
        signed char *pb7 = pb0 + 7*ldb;

        int *pc = (int*)C + i * ldc + j;

        __m256i ma0, ma1; //ma2, ma3, ma4, ma5, ma6, ma7;
        __m256i mb0, mb1, mb2, mb3, mb4, mb5, mb6, mb7;
        __m256i mc0, mc1; //mc2, mc3, mc4, mc5, mc6, mc7;

        __m256i sum0 = _mm256_setzero_si256();
        __m256i sum1 = _mm256_setzero_si256();
        __m256i sum2 = _mm256_setzero_si256();
        __m256i sum3 = _mm256_setzero_si256();
        __m256i sum4 = _mm256_setzero_si256();
        __m256i sum5 = _mm256_setzero_si256();
        __m256i sum6 = _mm256_setzero_si256();
        __m256i sum7 = _mm256_setzero_si256();

        __m256i sum8 = _mm256_setzero_si256();
        __m256i sum9 = _mm256_setzero_si256();
        __m256i sum10 = _mm256_setzero_si256();
        __m256i sum11 = _mm256_setzero_si256();
        __m256i sum12 = _mm256_setzero_si256();
        __m256i sum13 = _mm256_setzero_si256();
        __m256i sum14 = _mm256_setzero_si256();
        __m256i sum15 = _mm256_setzero_si256();

        for (k = 0; k < d; k += 32) {
            //__m128i low0, low1, low2, low3;
            //__m128i hi0, hi1, hi2, hi3;

            ma0 = _mm256_loadu_si256((__m256i*)pa0);
            ma1 = _mm256_loadu_si256((__m256i*)pa1);

            mb0 = _mm256_loadu_si256((__m256i*)pb0);
            mb1 = _mm256_loadu_si256((__m256i*)pb1);
            mb2 = _mm256_loadu_si256((__m256i*)pb2);
            mb3 = _mm256_loadu_si256((__m256i*)pb3);
            mb4 = _mm256_loadu_si256((__m256i*)pb4);
            mb5 = _mm256_loadu_si256((__m256i*)pb5);
            mb6 = _mm256_loadu_si256((__m256i*)pb6);
            mb7 = _mm256_loadu_si256((__m256i*)pb7);

            _mm_prefetch((char *)pa0 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pa1 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb0 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb1 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb2 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb3 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb4 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb5 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb6 + 32, _MM_HINT_T0);
            _mm_prefetch((char *)pb7 + 32, _MM_HINT_T0);

            mc0 = _mm256_maddubs_epi16(ma0, mb0);
            sum0 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum0);

            mc0 = _mm256_maddubs_epi16(ma0, mb1);
            sum1 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum1);

            mc0 = _mm256_maddubs_epi16(ma0, mb2);
            sum2 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum2);

            mc0 = _mm256_maddubs_epi16(ma0, mb3);
            sum3 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum3);

            mc0 = _mm256_maddubs_epi16(ma0, mb4);
            sum4 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum4);

            mc0 = _mm256_maddubs_epi16(ma0, mb5);
            sum5 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum5);

            mc0 = _mm256_maddubs_epi16(ma0, mb6);
            sum6 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum6);

            mc0 = _mm256_maddubs_epi16(ma0, mb7);
            sum7 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum7);

            //
            mc0 = _mm256_maddubs_epi16(ma1, mb0);
            sum8 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum8);

            mc0 = _mm256_maddubs_epi16(ma1, mb1);
            sum9 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum9);

            mc0 = _mm256_maddubs_epi16(ma1, mb2);
            sum10 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum10);

            mc0 = _mm256_maddubs_epi16(ma1, mb3);
            sum11 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum11);

            mc0 = _mm256_maddubs_epi16(ma1, mb4);
            sum12 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum12);

            mc0 = _mm256_maddubs_epi16(ma1, mb5);
            sum13 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum13);

            mc0 = _mm256_maddubs_epi16(ma1, mb6);
            sum14 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum14);

            mc0 = _mm256_maddubs_epi16(ma1, mb7);
            sum15 = _mm256_add_epi32(_mm256_madd_epi16(mc0, tmp_short), sum15);

            //
            pa0+=32; pa1+=32; //pa2+=32; pa3+=32;
            pb0+=32; pb1+=32; pb2+=32; pb3+=32;
            pb4+=32; pb5+=32; pb6+=32; pb7+=32;
        }

        sum0 = _mm256_add_epi32(sum0, _mm256_permute2x128_si256(sum0, sum0, 0x81));
        sum0 = _mm256_add_epi32(sum0, _mm256_srli_si256(sum0, 8));
        sum0 = _mm256_add_epi32(sum0, _mm256_srli_si256(sum0, 4));

        sum1 = _mm256_add_epi32(sum1, _mm256_permute2x128_si256(sum1, sum1, 0x81));
        sum1 = _mm256_add_epi32(sum1, _mm256_srli_si256(sum1, 8));
        sum1 = _mm256_add_epi32(sum1, _mm256_srli_si256(sum1, 4));

        sum2 = _mm256_add_epi32(sum2, _mm256_permute2x128_si256(sum2, sum2, 0x81));
        sum2 = _mm256_add_epi32(sum2, _mm256_srli_si256(sum2, 8));
        sum2 = _mm256_add_epi32(sum2, _mm256_srli_si256(sum2, 4));

        sum3 = _mm256_add_epi32(sum3, _mm256_permute2x128_si256(sum3, sum3, 0x81));
        sum3 = _mm256_add_epi32(sum3, _mm256_srli_si256(sum3, 8));
        sum3 = _mm256_add_epi32(sum3, _mm256_srli_si256(sum3, 4));

        sum4 = _mm256_add_epi32(sum4, _mm256_permute2x128_si256(sum4, sum4, 0x81));
        sum4 = _mm256_add_epi32(sum4, _mm256_srli_si256(sum4, 8));
        sum4 = _mm256_add_epi32(sum4, _mm256_srli_si256(sum4, 4));

        sum5 = _mm256_add_epi32(sum5, _mm256_permute2x128_si256(sum5, sum5, 0x81));
        sum5 = _mm256_add_epi32(sum5, _mm256_srli_si256(sum5, 8));
        sum5 = _mm256_add_epi32(sum5, _mm256_srli_si256(sum5, 4));

        sum6 = _mm256_add_epi32(sum6, _mm256_permute2x128_si256(sum6, sum6, 0x81));
        sum6 = _mm256_add_epi32(sum6, _mm256_srli_si256(sum6, 8));
        sum6 = _mm256_add_epi32(sum6, _mm256_srli_si256(sum6, 4));

        sum7 = _mm256_add_epi32(sum7, _mm256_permute2x128_si256(sum7, sum7, 0x81));
        sum7 = _mm256_add_epi32(sum7, _mm256_srli_si256(sum7, 8));
        sum7 = _mm256_add_epi32(sum7, _mm256_srli_si256(sum7, 4));

        sum8 = _mm256_add_epi32(sum8, _mm256_permute2x128_si256(sum8, sum8, 0x81));
        sum8 = _mm256_add_epi32(sum8, _mm256_srli_si256(sum8, 8));
        sum8 = _mm256_add_epi32(sum8, _mm256_srli_si256(sum8, 4));

        sum9 = _mm256_add_epi32(sum9, _mm256_permute2x128_si256(sum9, sum9, 0x81));
        sum9 = _mm256_add_epi32(sum9, _mm256_srli_si256(sum9, 8));
        sum9 = _mm256_add_epi32(sum9, _mm256_srli_si256(sum9, 4));

        sum10 = _mm256_add_epi32(sum10, _mm256_permute2x128_si256(sum10, sum10, 0x81));
        sum10 = _mm256_add_epi32(sum10, _mm256_srli_si256(sum10, 8));
        sum10 = _mm256_add_epi32(sum10, _mm256_srli_si256(sum10, 4));

        sum11 = _mm256_add_epi32(sum11, _mm256_permute2x128_si256(sum11, sum11, 0x81));
        sum11 = _mm256_add_epi32(sum11, _mm256_srli_si256(sum11, 8));
        sum11 = _mm256_add_epi32(sum11, _mm256_srli_si256(sum11, 4));

        sum12 = _mm256_add_epi32(sum12, _mm256_permute2x128_si256(sum12, sum12, 0x81));
        sum12 = _mm256_add_epi32(sum12, _mm256_srli_si256(sum12, 8));
        sum12 = _mm256_add_epi32(sum12, _mm256_srli_si256(sum12, 4));

        sum13 = _mm256_add_epi32(sum13, _mm256_permute2x128_si256(sum13, sum13, 0x81));
        sum13 = _mm256_add_epi32(sum13, _mm256_srli_si256(sum13, 8));
        sum13 = _mm256_add_epi32(sum13, _mm256_srli_si256(sum13, 4));

        sum14 = _mm256_add_epi32(sum14, _mm256_permute2x128_si256(sum14, sum14, 0x81));
        sum14 = _mm256_add_epi32(sum14, _mm256_srli_si256(sum14, 8));
        sum14 = _mm256_add_epi32(sum14, _mm256_srli_si256(sum14, 4));

        sum15 = _mm256_add_epi32(sum15, _mm256_permute2x128_si256(sum15, sum15, 0x81));
        sum15 = _mm256_add_epi32(sum15, _mm256_srli_si256(sum15, 8));
        sum15 = _mm256_add_epi32(sum15, _mm256_srli_si256(sum15, 4));

        pc[0] = _mm256_extract_epi32(sum0, 0);
        pc[1] = _mm256_extract_epi32(sum1, 0);
        pc[2] = _mm256_extract_epi32(sum2, 0);
        pc[3] = _mm256_extract_epi32(sum3, 0);
        pc[4] = _mm256_extract_epi32(sum4, 0);
        pc[5] = _mm256_extract_epi32(sum5, 0);
        pc[6] = _mm256_extract_epi32(sum6, 0);
        pc[7] = _mm256_extract_epi32(sum7, 0);

        pc[ldc+0] = _mm256_extract_epi32(sum8, 0);
        pc[ldc+1] = _mm256_extract_epi32(sum9, 0);
        pc[ldc+2] = _mm256_extract_epi32(sum10, 0);
        pc[ldc+3] = _mm256_extract_epi32(sum11, 0);
        pc[ldc+4] = _mm256_extract_epi32(sum12, 0);
        pc[ldc+5] = _mm256_extract_epi32(sum13, 0);
        pc[ldc+6] = _mm256_extract_epi32(sum14, 0);
        pc[ldc+7] = _mm256_extract_epi32(sum15, 0);

    }
}
}

0 个答案:

没有答案