我有一个对大量字节进行操作的算法。作为一个预处理步骤,我需要为给定的索引创建一个计数,该位数是到目前为止在数组中设置的频率。
我可以使用以下(伪)代码在C中执行此操作:
input: uint8_t values[COUNT];
output: uint32_t bitsum[COUNT+1][8];
(bitsum[i][x] is the counter for the x-th bit being set in
the PRECEEDING i bytes -- this makes bitsum[0][x] == 0)
// we skip first row
for (int i=1; i < COUNT+1; i++) {
for (int bit=0; bit < 8; bit++) {
bitsum[i][bit] = bitsum[i-1][bit];
if (values[i-1] & (1 << bit) != 0) {
bitsum[i][bit]++;
}
}
}
然而,我很想知道我可以使用NEON SIMD更快地实现这一目标。不幸的是,我对此很陌生,所以我无法解决这个问题(但是?)并寻求一些帮助。甚至可以在NEON中这样做吗?
更新
尝试在C中加速此代码,我相信以下方法是最快的(当然,没有展开内部for循环):
// pre-calculate lookup-table
uint16_t lookup[256][8];
for (int value=0; value < 256; value++) {
for (int bit=0; bit < 8; bit++) {
if (value & (1 << bit) != 0) {
lookup[value][bit]++;
}
}
}
// create sum
for (int i=1; i < COUNT+1; i++) {
for (int bit=0; bit < 8; bit++) {
bitsum[i][bit] = bitsum[i-1][bit] + lookup[values[i-1]][bit];
}
}
除了查找表访问外,这看起来对SIMD来说是理想的 - 至少我找不到在NEON中这样做的方法。
答案 0 :(得分:1)
您可以使用VTBL
和VTBX
指令在NEON中执行表查找,但它们仅对具有少量条目的查找表有用。在针对NEON进行优化时,通常最好寻找一种在运行时计算值的方法,而不是使用表格。
在此示例中,可以直接在运行时计算查找。该功能基本上是
int lookup(int val, int bit) { return (val & (1<<bit) >> bit); }
可轻松转换为NEON SIMD。
因此,您的函数可以使用NEON内在函数实现,如下所示:
#include <arm_neon.h>
void f(uint32_t *output, const uint8_t *input, int length)
{
static const uint8_t mask_vals[] = { 0x1, 0x2, 0x4, 0x8,
0x10, 0x20, 0x40, 0x80 };
/* NEON shifts are left shifts, and we want a right shift,
so use negative numbers here */
static const int8_t shift_vals[] = { 0, -1, -2, -3, -4, -5, -6, -7 };
/* constants we need in the main loop */
uint8x8_t mask = vld1_u8(mask_vals);
int8x8_t shift = vld1_s8(shift_vals);
/* accumulators for results, bits 0-3 in cumul1, bits 4-7 in cumul2 */
uint32x4_t cumul1 = vdupq_n_u32(0);
uint32x4_t cumul2 = vdupq_n_u32(0);
for (int i = 0; i < length; i++)
{
uint8x8_t v = vld1_dup_u8(input+i);
/* this gives 0 or 1 in each lane, depending on whether the
appropriate bit is set */
uint8x8_t incr = vshl_u8(vand_u8(v, mask), shift);
/* widen to 16 bits */
uint16x8_t incr_w = vmovl_u8(incr);
/* increment the accumulators */
cumul1 = vaddw_u16(cumul1, vget_low_u16(incr_w));
cumul2 = vaddw_u16(cumul2, vget_high_u16(incr_w));
/* store the accumulator values */
vst1q_u32(output + i*8, cumul1);
vst1q_u32(output + i*8 + 4, cumul2);
}
}
免责声明:此代码已编译,但我尚未对其进行测试。