C++中的卷积能否比MATLAB conv2更快?

时间:2016-03-12 17:12:06

标签: matlab convolution

最近我正致力于加速mex中的卷积。我先使用了OpenCV的filter2D,但它比MATLAB慢;然后我基于SSE和OpenMP编写了自己的卷积,但在同一台机器上似乎仍然无法比MATLAB更快。我不能使用FFT,因为我的卷积核只有5 * 5;我也不能分离卷积核,因为它不是可分离核,核中的元素是随机的。我还能做得更快吗?

实际上,如果我使用更好的机器,例如线程数更多的机器,我可以达到与MATLAB(conv2)相同的速度。我想问的是,我能否在普通机器上达到这个速度(我现在使用的是2核8线程的机器)。MATLAB测试代码:

 % Benchmark script: compares MATLAB's conv2 against the MEX function
 % Verify_01_20160310 on a single-precision i_h-by-i_w image and a 5x5 kernel.
 i_h = 2000;
 i_w = 1000;   
 a1 = ones(i_h,i_w);
 for k=1:2000
     % Refill the image element by element with values 2+rand().
     for i=1:i_h
         for j=1:i_w
             a1(i,j)=2+rand();
         end
     end
     a1 = single(a1);
     % Build a fresh 5x5 kernel with entries 3+rand(), then convert it to single.
     b = ones(5,5);
     for i=1:5
         for j=1:5
             b(i,j)=3+rand();
         end
     end
     b = single(b);
     % Time the reference: conv2 with the 'same' shape (output has the size of a1).
     tic, 
     c = conv2(a1,b,'same');
     toc,

     % Time the MEX implementation on the same inputs.
     tic,
     c1 = Verify_01_20160310(a1,b);
     toc,
     % Display the largest absolute element-wise difference between the two results.
     % NOTE(review): the variable name diff shadows MATLAB's built-in diff function.
     c3=c-c1;
     diff = max(abs(c3(:)));
     disp(diff);
 end

Mex代码:

    // SSE_Conv_Mex_01_20160310.cpp : Defines the MEX gateway (mexFunction) and the SSE/OpenMP 5x5 convolution routines it calls.
//

#include "SSE_1d_conv.h"
#include <iostream>
#include <windows.h>
#include <stdlib.h>
#include <smmintrin.h>
#include <omp.h>
#include "time.h"
#include <stdio.h>
#include <tchar.h>
#include "mex.h"
#include <ctime>

time_t t0,t1;


// 5x5 convolution of a 2-D float buffer, parallelised over lines with OpenMP.
//
// The buffer is treated as i_H contiguous lines of i_W floats each. For every
// output line, the five neighbouring source lines (offsets -2..+2) are each
// filtered by Conv1d_1line_SSE with one 5-element row of piKernel25, and the
// five partial results are summed into psh_2D_dest. Source lines that fall
// outside [0, i_H) contribute zeros.
//
//   psh_2D_source   input, i_H lines of i_W floats
//   i_W             line length; the summation loop handles 4 floats per step,
//                   so it must be a multiple of 4
//   i_H             number of lines
//   psh_2D_dest     output, same layout as the input; written with
//                   _mm_stream_ps, which requires 16-byte aligned addresses
//   piKernel25      25 kernel coefficients; row at offset 20 is paired with
//                   line i-2, offset 15 with i-1, ..., offset 0 with i+2
//   pi_WorkingLines scratch space of 5*i_W floats per OpenMP thread, read with
//                   _mm_load_ps (16-byte aligned addresses required)
//
// Always returns true.
bool Conv2d_32f_SSE(float *psh_2D_source, const int i_W, const int i_H,
                    float *psh_2D_dest, float *piKernel25, float *pi_WorkingLines)
{
#pragma omp parallel for
    for (int line = 0; line < i_H; ++line)
    {
        // Each thread owns five consecutive scratch lines.
        float *scratch = pi_WorkingLines + omp_get_thread_num() * 5 * i_W;
        float *tap[5];

        // Slot s holds the filtered source line (line + s - 2).
        for (int slot = 0; slot < 5; ++slot)
        {
            const int src_line = line + slot - 2;
            tap[slot] = scratch + slot * i_W;

            if (src_line >= 0 && src_line < i_H)
            {
                // Kernel rows are applied in reverse order: slot 0 -> offset 20, slot 4 -> offset 0.
                Conv1d_1line_SSE(psh_2D_source + src_line * i_W, i_W,
                                 piKernel25 + (4 - slot) * 5, tap[slot]);
            }
            else
            {
                memset(tap[slot], 0, i_W * sizeof(float));
            }
        }

        // Accumulate the five partial lines, innermost pair first, and stream
        // the sum to the destination line.
        float *out_line = psh_2D_dest + line * i_W;
        for (int col = 0; col < i_W; col += 4)
        {
            __m128 acc = _mm_add_ps(_mm_load_ps(tap[1] + col), _mm_load_ps(tap[0] + col));
            acc = _mm_add_ps(_mm_load_ps(tap[2] + col), acc);
            acc = _mm_add_ps(_mm_load_ps(tap[3] + col), acc);
            acc = _mm_add_ps(_mm_load_ps(tap[4] + col), acc);
            _mm_stream_ps(out_line + col, acc);
        }
    }
    return true;
}


bool Conv1d_1line_SSE(float *pshLine_in, const int i_W, float *piKernel5, float *piLine_out)
{
    __m128 mf_zero = _mm_setzero_ps();
    __m128 mf_f32,
           mf_Multd32_1_1, mf_Multd32_1_2,
           mf_Multd32_2_1, mf_Multd32_2_2,
           mf_Multd32_1_2_prev = mf_zero, mf_Multd32_1_1_prev = mf_zero,
           mf_Multd32_2_1_prev, mf_Multd32_2_2_prev,
           mf_sum;

    __m128 mf_KernLeft2  = _mm_set1_ps(piKernel5[4]),
           mf_KernLeft1  = _mm_set1_ps(piKernel5[3]),
           mf_KernCen    = _mm_set1_ps(piKernel5[2]),
           mf_KernRight1 = _mm_set1_ps(piKernel5[1]),
           mf_KernRight2 = _mm_set1_ps(piKernel5[0]);

    int i;
    // if the variable has two 1 or 2, then the first 1 or 2 represent the Kernel before center or after center, the second 1 or 2 represent the distance between the kernel and the center kernel
    mf_f32 = _mm_load_ps(pshLine_in);

    mf_Multd32_1_2 = _mm_mul_ps(mf_f32, mf_KernLeft2);
    mf_Multd32_1_1 = _mm_mul_ps(mf_f32, mf_KernLeft1);

    mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)(&mf_Multd32_1_2), *(__m128i*)(&mf_Multd32_1_2_prev), 2*sizeof(float)), _mm_mul_ps(mf_f32, mf_KernCen));
    mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)(&mf_Multd32_1_1), *(__m128i*)(&mf_Multd32_1_1_prev), 3*sizeof(float)), mf_sum);

    mf_Multd32_1_1_prev = mf_Multd32_1_1;
    mf_Multd32_1_2_prev = mf_Multd32_1_2;

    mf_Multd32_2_2 = _mm_mul_ps(mf_f32, mf_KernRight2);
    mf_Multd32_2_1 = _mm_mul_ps(mf_f32, mf_KernRight1);

    mf_Multd32_2_2_prev = mf_Multd32_2_2;
    mf_Multd32_2_1_prev = mf_Multd32_2_1;

    for(i=4; i<i_W; i+=4)
    {
        mf_f32 = _mm_load_ps(pshLine_in+i);

        mf_Multd32_2_1 = _mm_mul_ps(mf_f32, mf_KernRight1);
        mf_Multd32_2_2 = _mm_mul_ps(mf_f32, mf_KernRight2);


        mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_2_1, *(__m128i*)&mf_Multd32_2_1_prev, 1*sizeof(float)), mf_sum);
        mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_2_2, *(__m128i*)&mf_Multd32_2_2_prev, 2*sizeof(float)), mf_sum);

        mf_Multd32_2_2_prev = mf_Multd32_2_2;
        mf_Multd32_2_1_prev = mf_Multd32_2_1;

        _mm_store_ps(piLine_out+i-4, mf_sum);

        mf_Multd32_1_1 = _mm_mul_ps(mf_f32, mf_KernLeft1); 
        mf_Multd32_1_2 = _mm_mul_ps(mf_f32, mf_KernLeft2);

        mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_1_2, *(__m128i*)&mf_Multd32_1_2_prev, 2*sizeof(float)), _mm_mul_ps(mf_f32, mf_KernCen));
        mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_1_1, *(__m128i*)&mf_Multd32_1_1_prev, 3*sizeof(float)), mf_sum);

        mf_Multd32_1_1_prev = mf_Multd32_1_1;
        mf_Multd32_1_2_prev = mf_Multd32_1_2;
    }
    mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_zero, *(__m128i*)&mf_Multd32_2_1, 1*sizeof(float)), mf_sum);
    mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_zero, *(__m128i*)&mf_Multd32_2_2, 2*sizeof(float)), mf_sum);

    _mm_store_ps(piLine_out+i-4, mf_sum);

    return true;
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray* prhs[])
{

        int input1_row = (int)mxGetM(prhs[0]);
        int input1_column = (int)mxGetN(prhs[0]); 

        float *input1;


        float *input2;

        const int size_out_arr[2]={input1_row,input1_column};

        plhs[0] = mxCreateNumericArray(2,size_out_arr,mxSINGLE_CLASS,mxREAL);
        float *output; 

        int i_omp_max_threads = omp_get_max_threads();
        float *pi_WorkingLines =(float*)_mm_malloc(i_omp_max_threads*5*input1_row*sizeof(float), 32);


        input1 = (float*)mxGetData(prhs[0]);
        input2 = (float*)mxGetData(prhs[1]);
        output = (float*)mxGetData(plhs[0]);
        t0=clock();
        Conv2d_32f_SSE(input1,input1_row,input1_column,output,input2,pi_WorkingLines);
        t1=clock();
        mexPrintf("convol time: %d\n", t1-t0);
        _mm_free(pi_WorkingLines);        
}

下面是我调用conv2将近40 * 36次的代码:P.nScales是36,opts.filters是一组5 * 5 * 40的三维滤波器,在repmat之后,C的大小变为548 * 966 * 40。

 % For every scale i: replicate P.data{i} size(fs,4) times along the third
 % dimension, convolve each slice C(:,:,j) with fs(:,:,j), then store the result
 % of imResample(C,.5) back into P.data{i}.
 % NOTE(review): conv2 is called without a shape argument, so it returns the
 % 'full' result, which is larger than C(:,:,j) -- confirm whether 'same' was intended.
 % NOTE(review): imResample is defined elsewhere; presumably it resamples C by a
 % factor of 0.5 -- verify against its definition.
 for i=1:P.nScales, fs=opts.filters;
      C=repmat(P.data{i},[1 1 size(fs,4)]);           
      for j=1:size(C,3)
       C(:,:,j)=conv2(C(:,:,j),fs(:,:,j)); 
      end  
      P.data{i}=imResample(C,.5); 
  end

end

由于卷积使用128位寄存器并且数据类型是单精度浮点数,因此它一次可以处理4个浮点数。这意味着,你应该将图像高度(行数)设置为4的整数倍(因为MATLAB按列优先存储数据,同一列内沿行方向的元素在内存中是连续的,所以mex中的卷积先沿行方向处理)。

0 个答案:

没有答案