我正在尝试使用Neon内在函数为ARM A8处理器编写优化的点积,但我遇到了一些麻烦。首先,是否有任何库已经实现了这个?我的代码似乎有效,但在运行时会导致一些安静的失败 - 我最好的猜测是因为与未经优化的代码相比,精度略有下降。有没有更好的方法来完成我想要做的事情?我将非常感谢任何帮助或建议。提前谢谢。
这个特殊的点积是一个32位浮点* 32位浮点复数。
这是未经优化的代码:
double sum_re = 0.0;
double sum_im = 0.0;
for(int i=0; i<len; i++, src1++, src2++)
{
sum_re += *src1 * src2->re;
sum_im += *src1 * src2->im;
}
这是我的优化版本:
float sum_re = 0.0;
float sum_im = 0.0;
float to_sum_re[4] = {0,0,0,0};
float to_sum_im[4] = {0,0,0,0};
float32x4_t tmp_sum_re, tmp_sum_im, source1;
float32x4x2_t source2;
tmp_sum_re = vld1q_f32(to_sum_re);
tmp_sum_im = vld1q_f32(to_sum_im);
int i = 0;
while (i < (len & ~3)) {
source1 = vld1q_f32(&src1[i]);
source2 = vld2q_f32((const float32_t*)&src2[i]);
tmp_sum_re = vmlaq_f32(tmp_sum_re, source1, source2.val[0]);
tmp_sum_im = vmlaq_f32(tmp_sum_im, source1, source2.val[1]);
i += 4;
}
if (len & ~3) {
vst1q_f32(to_sum_re, tmp_sum_re);
vst1q_f32(to_sum_im, tmp_sum_im);
sum_re += to_sum_re[0] + to_sum_re[1] + to_sum_re[2] + to_sum_re[3];
sum_im += to_sum_im[0] + to_sum_im[1] + to_sum_im[2] + to_sum_im[3];
}
while (i < len)
{
sum_re += src1[i] * src2[i].re;
sum_im += src1[i] * src2[i].im;
i++;
}
答案 0 :(得分:5)
如果您使用的是iOS,请在Accelerate框架中使用vDSP_zrdotpr。 (vDSP_zrdotpr返回带有复矢量的实矢量的点积。还有其他变量,例如实数到实数或复数到复数。)
当然会失去精确度;你的未经优化的代码会累积双精度和,而NEON代码会累积单精度和。
即使没有精确更改,结果也会有所不同,因为以不同顺序执行浮点运算会产生不同的舍入误差。 (对于整数也是如此;如果计算7/3 * 5,则得到10,但5 * 7/3为11。)
有一些算法可以进行浮点运算,减少误差。但是,对于高性能点阵产品,您通常会遇到困难。
一种选择是使用双精度NEON指令进行算术运算。当然,这不会像单精度NEON一样快,但它会比标量(非NEON)代码更快。
答案 1 :(得分:0)
至于其他实现,还有来自ARM的NEON OpenMAX DL实现。 链接到http://www.arm.com/community/multimedia/standards-apis.php。
下载需要注册,格式是RVCT汇编程序,但是为了查看一组如何使用NEON(包括点积实现)的例子,它非常好。
答案 2 :(得分:0)
这是我做过的,现在是商业产品。希望它有所帮助。唯一的要求是两个被乘数(src1,srcs-&gt; re)必须是四的倍数。
float dotProduct4 (const float *a, const float *b, int n) {
float net1D=0.0f;
assert(n%4==0); // required floats 'a' & 'b' to be multiple of 4
#ifdef __ARM_NEON__
asm volatile (
"vmov.f32 q8, #0.0 \n\t" // zero out q8 register
"1: \n\t"
"subs %3, %3, #4 \n\t" // we load 4 floats into q0, and q2 register
"vld1.f32 {d0,d1}, [%1]! \n\t" // loads q0, update pointer *a
"vld1.f32 {d4,d5}, [%2]! \n\t" // loads q2, update pointer *b
"vmla.f32 q8, q0, q2 \n\t" // store four partial sums in q8
"bgt 1b \n\t" // loops to label 1 until n==0
"vpadd.f32 d0, d16, d17 \n\t" // pairwise add 4 partial sums in q8, store in d0
"vadd.f32 %0, s0, s1 \n\t" // add 2 partial sums in d0, store result in return variable net1D
: "=w"(net1D) // output
: "r"(a), "r"(b), "r"(n) // input
: "q0", "q2", "q8"); // clobber list
#else
for (int k=0; k < n; k++) {
net1D += a[k] * b[k];
}
#endif
return net1D;
}