最近我正致力于加速mex中的卷积。我使用opencv的filter2D,但它比MATLAB慢;然后我基于SSE和OpenMP编写了自己的卷积,但在同一台机器上似乎无法比MATLAB更快。我不能使用FFT,因为我的卷积核只有5×5;我也不能把核分离,因为它不可分离——只是由随机元素构成的核。我还能更快吗?
实际上,如果我使用更好的机器(例如线程更多的机器),我可以达到与MATLAB(conv2)相同的速度。我想问的是,用普通的机器(我现在用的是2核8线程的机器)能否达到这个速度。` MATLAB测试代码:
% Benchmark: compare MATLAB's conv2 against the SSE/OpenMP MEX implementation
% (Verify_01_20160310) on random single-precision data, and report the maximum
% absolute difference between the two results.
i_h = 2000;   % image height (must be a multiple of 4 for the MEX kernel)
i_w = 1000;   % image width
for k = 1:2000
    % Random single-precision test image with values in [2,3).
    % Vectorized rand(i_h,i_w) replaces the original element-by-element
    % double loop, which dominated the benchmark's per-iteration cost.
    a1 = single(2 + rand(i_h, i_w));
    % Random single-precision 5x5 kernel with values in [3,4).
    b = single(3 + rand(5, 5));
    tic,
    c = conv2(a1, b, 'same');
    toc,
    tic,
    c1 = Verify_01_20160310(a1, b);
    toc,
    % Maximum absolute discrepancy between the two implementations.
    % (Named max_err instead of "diff" to avoid shadowing the built-in diff.)
    c3 = c - c1;
    max_err = max(abs(c3(:)));
    disp(max_err);
end
` Mex代码:
// SSE_Conv_Mex_01_20160310.cpp : Defines the entry point for the console application.
//
#include "SSE_1d_conv.h"
#include <iostream>
#include <windows.h>
#include <stdlib.h>
#include <smmintrin.h>
#include <omp.h>
#include "time.h"
#include <stdio.h>
#include <tchar.h>
#include "mex.h"
#include <ctime>
time_t t0,t1;
/*
 * 2-D 5x5 "same"-size convolution (zero padding) of an i_W x i_H
 * single-precision image, parallelized over lines with OpenMP.
 *
 * psh_2D_source  input image, i_H lines of i_W contiguous floats
 * i_W            line length; assumed to be a multiple of 4 and the buffers
 *                16-byte aligned (SSE aligned load/store) -- not checked here
 * i_H            number of lines
 * psh_2D_dest    output image, same layout as the input
 * piKernel25     25 kernel coefficients; row r lives at piKernel25 + 5*r and
 *                is applied to input line i+2-r (note the vertical flip:
 *                line i-2 gets row 4, line i+2 gets row 0), so together with
 *                the horizontal flip inside Conv1d_1line_SSE this matches
 *                true convolution (MATLAB conv2), not correlation
 * pi_WorkingLines scratch buffer holding 5 lines of i_W floats PER THREAD,
 *                i.e. at least omp_get_max_threads()*5*i_W floats
 *
 * Strategy: for each output line i, run the 1-D horizontal convolution on the
 * five contributing input lines (i-2 .. i+2), each with its own kernel row,
 * into the thread's five scratch lines; lines outside the image are zeroed
 * (zero padding).  The output line is then the element-wise sum of the five
 * scratch lines, written with non-temporal stores.
 */
bool Conv2d_32f_SSE(float *psh_2D_source, const int i_W, const int i_H,
float *psh_2D_dest, float *piKernel25, float *pi_WorkingLines)
{
#pragma omp parallel
{
/* Per-thread pointers: input lines, scratch lines, and result aliases. */
float *pshLine_in_prev2, *pshLine_in_prev1, *pshLine_in_curr,
*pshLine_in_next1, *pshLine_in_next2, *piLine_out;
float *piLine_prev2, *piLine_prev1, *piLine_curr,
*piLine_next1, *piLine_next2;
float *piLine_res_prev2, *piLine_res_prev1, *piLine_res_curr,
*piLine_res_next1, *piLine_res_next2;
int i,j,i_current_thread_num;
/* Each iteration is independent: it only writes output line i and the
 * calling thread's own 5-line slice of pi_WorkingLines. */
#pragma omp for
for(i=0; i<i_H; i++)
{
piLine_out = psh_2D_dest + i*i_W;
i_current_thread_num = omp_get_thread_num();
/* Scratch line 0 of this thread's slice: contribution of line i-2
 * convolved with kernel row 4 (piKernel25+20). */
piLine_prev2 = pi_WorkingLines + i_current_thread_num*5*i_W;
if(i>=2)
{
pshLine_in_prev2 = psh_2D_source + (i-2)*i_W;
Conv1d_1line_SSE(pshLine_in_prev2, i_W, piKernel25+20, piLine_prev2);
}
else memset(piLine_prev2,0,i_W*sizeof(float));
piLine_res_prev2 = piLine_prev2;
/* Scratch line 1: line i-1 with kernel row 3. */
piLine_prev1 = pi_WorkingLines + (i_current_thread_num*5+1)*i_W;
if(i>=1)
{
pshLine_in_prev1 = psh_2D_source + (i-1)*i_W;
Conv1d_1line_SSE(pshLine_in_prev1, i_W, piKernel25+15, piLine_prev1);
}
else memset(piLine_prev1,0,i_W*sizeof(float));
piLine_res_prev1 = piLine_prev1;
/* Scratch line 2: line i itself with the center kernel row 2. */
piLine_curr = pi_WorkingLines + (i_current_thread_num*5+2)*i_W;
pshLine_in_curr = psh_2D_source + i*i_W;
Conv1d_1line_SSE(pshLine_in_curr, i_W, piKernel25+10, piLine_curr);
piLine_res_curr = piLine_curr;
/* Scratch line 4: line i+2 with kernel row 0. */
piLine_next2 = pi_WorkingLines + (i_current_thread_num*5+4)*i_W;
if(i<i_H-2)
{
pshLine_in_next2 = psh_2D_source + (i+2)*i_W;
Conv1d_1line_SSE(pshLine_in_next2, i_W, piKernel25, piLine_next2);
}
else memset(piLine_next2,0,i_W*sizeof(float));
piLine_res_next2 = piLine_next2;
/* Scratch line 3: line i+1 with kernel row 1. */
piLine_next1 = pi_WorkingLines + (i_current_thread_num*5+3)*i_W;
if(i<i_H-1)
{
pshLine_in_next1 = psh_2D_source + (i+1)*i_W;
Conv1d_1line_SSE(pshLine_in_next1, i_W, piKernel25+5, piLine_next1);
}
else memset(piLine_next1,0,i_W*sizeof(float));
piLine_res_next1 = piLine_next1;
/* Sum the five 1-D results, 4 floats per step.  _mm_stream_ps writes
 * non-temporally to avoid polluting the cache with the output line;
 * it requires piLine_out (hence i_W) to be 16-byte aligned. */
for(j=0; j<i_W; piLine_res_prev2 += 4, piLine_res_prev1 += 4,piLine_res_curr +=4, piLine_res_next1 +=4, piLine_res_next2 +=4, piLine_out +=4, j+=4)
{
_mm_stream_ps(piLine_out,_mm_add_ps(_mm_load_ps(piLine_res_next2),_mm_add_ps(_mm_load_ps(piLine_res_next1),
_mm_add_ps(_mm_load_ps(piLine_res_curr),_mm_add_ps(_mm_load_ps(piLine_res_prev1),
_mm_load_ps(piLine_res_prev2))))));
}
}
}
return true;
}
bool Conv1d_1line_SSE(float *pshLine_in, const int i_W, float *piKernel5, float *piLine_out)
{
__m128 mf_zero = _mm_setzero_ps();
__m128 mf_f32,
mf_Multd32_1_1, mf_Multd32_1_2,
mf_Multd32_2_1, mf_Multd32_2_2,
mf_Multd32_1_2_prev = mf_zero, mf_Multd32_1_1_prev = mf_zero,
mf_Multd32_2_1_prev, mf_Multd32_2_2_prev,
mf_sum;
__m128 mf_KernLeft2 = _mm_set1_ps(piKernel5[4]),
mf_KernLeft1 = _mm_set1_ps(piKernel5[3]),
mf_KernCen = _mm_set1_ps(piKernel5[2]),
mf_KernRight1 = _mm_set1_ps(piKernel5[1]),
mf_KernRight2 = _mm_set1_ps(piKernel5[0]);
int i;
// if the variable has two 1 or 2, then the first 1 or 2 represent the Kernel before center or after center, the second 1 or 2 represent the distance between the kernel and the center kernel
mf_f32 = _mm_load_ps(pshLine_in);
mf_Multd32_1_2 = _mm_mul_ps(mf_f32, mf_KernLeft2);
mf_Multd32_1_1 = _mm_mul_ps(mf_f32, mf_KernLeft1);
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)(&mf_Multd32_1_2), *(__m128i*)(&mf_Multd32_1_2_prev), 2*sizeof(float)), _mm_mul_ps(mf_f32, mf_KernCen));
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)(&mf_Multd32_1_1), *(__m128i*)(&mf_Multd32_1_1_prev), 3*sizeof(float)), mf_sum);
mf_Multd32_1_1_prev = mf_Multd32_1_1;
mf_Multd32_1_2_prev = mf_Multd32_1_2;
mf_Multd32_2_2 = _mm_mul_ps(mf_f32, mf_KernRight2);
mf_Multd32_2_1 = _mm_mul_ps(mf_f32, mf_KernRight1);
mf_Multd32_2_2_prev = mf_Multd32_2_2;
mf_Multd32_2_1_prev = mf_Multd32_2_1;
for(i=4; i<i_W; i+=4)
{
mf_f32 = _mm_load_ps(pshLine_in+i);
mf_Multd32_2_1 = _mm_mul_ps(mf_f32, mf_KernRight1);
mf_Multd32_2_2 = _mm_mul_ps(mf_f32, mf_KernRight2);
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_2_1, *(__m128i*)&mf_Multd32_2_1_prev, 1*sizeof(float)), mf_sum);
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_2_2, *(__m128i*)&mf_Multd32_2_2_prev, 2*sizeof(float)), mf_sum);
mf_Multd32_2_2_prev = mf_Multd32_2_2;
mf_Multd32_2_1_prev = mf_Multd32_2_1;
_mm_store_ps(piLine_out+i-4, mf_sum);
mf_Multd32_1_1 = _mm_mul_ps(mf_f32, mf_KernLeft1);
mf_Multd32_1_2 = _mm_mul_ps(mf_f32, mf_KernLeft2);
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_1_2, *(__m128i*)&mf_Multd32_1_2_prev, 2*sizeof(float)), _mm_mul_ps(mf_f32, mf_KernCen));
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_Multd32_1_1, *(__m128i*)&mf_Multd32_1_1_prev, 3*sizeof(float)), mf_sum);
mf_Multd32_1_1_prev = mf_Multd32_1_1;
mf_Multd32_1_2_prev = mf_Multd32_1_2;
}
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_zero, *(__m128i*)&mf_Multd32_2_1, 1*sizeof(float)), mf_sum);
mf_sum = _mm_add_ps(*(__m128*)&_mm_alignr_epi8(*(__m128i*)&mf_zero, *(__m128i*)&mf_Multd32_2_2, 2*sizeof(float)), mf_sum);
_mm_store_ps(piLine_out+i-4, mf_sum);
return true;
}
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray* prhs[])
{
int input1_row = (int)mxGetM(prhs[0]);
int input1_column = (int)mxGetN(prhs[0]);
float *input1;
float *input2;
const int size_out_arr[2]={input1_row,input1_column};
plhs[0] = mxCreateNumericArray(2,size_out_arr,mxSINGLE_CLASS,mxREAL);
float *output;
int i_omp_max_threads = omp_get_max_threads();
float *pi_WorkingLines =(float*)_mm_malloc(i_omp_max_threads*5*input1_row*sizeof(float), 32);
input1 = (float*)mxGetData(prhs[0]);
input2 = (float*)mxGetData(prhs[1]);
output = (float*)mxGetData(plhs[0]);
t0=clock();
Conv2d_32f_SSE(input1,input1_row,input1_column,output,input2,pi_WorkingLines);
t1=clock();
mexPrintf("convol time: %d\n", t1-t0);
_mm_free(pi_WorkingLines);
}
下面是我调用conv2接近40×36次的代码:P.nScales是36,而opts.filters是一个5×5×40的三维滤波器组;repmat之后,C变为548×966×40。 `
% For every scale, convolve the scale's channel with each of the 40 filters
% in the filter bank, then downsample the filtered stack by a factor of two.
for s = 1:P.nScales
    fs = opts.filters;
    % Replicate this scale's data across the filter-bank depth.
    C = repmat(P.data{s}, [1 1 size(fs, 4)]);
    for ch = 1:size(C, 3)
        C(:, :, ch) = conv2(C(:, :, ch), fs(:, :, ch));
    end
    P.data{s} = imResample(C, .5);
end
end
` 由于卷积使用128位寄存器并且数据类型是单精度浮点数,因此一次可以处理4个浮点数。这意味着你应该把图像高度设为4的整数倍(因为MATLAB是按列优先存储数据的,每一列在内存中连续,所以mex中的卷积是沿着每一列的连续内存进行处理的)。