快速“缩小”三维张量指数的方法

时间:2013-07-31 17:08:55

标签: c++ c bit-manipulation

对于C或C ++来说,这是一个有点棘手的问题。我在Ubuntu 12.04.2下运行GCC 4.6.3。

对于具有以下形式的三维张量,我有一个内存访问索引p

p = (i<<(2*N)) + (j<<N) + k

此处0 <= i,j,k < (1<<N)N有一些正整数。

现在我想计算i>>S, j>>S, k>>S 0 < S < N的“缩小”内存访问索引,其中包括:{/ p>

q = ((i>>S)<<(2*(N-S))) + ((j>>S)<<(N-S)) + (k>>S)

q计算p的最快方法是什么(事先不知道i,j,k)?我们可以假设0 < N <= 10(即p是32位整数)。我会对N=8的快速方法特别感兴趣(即i,j,k是8位整数)。 NS都是编译时常量。

N=8S=4的示例:

unsigned int p = 240407; // this is (3<<16) + (171<<8) + 23;
unsigned int q = 161; // this is (0<<8) + (10<<4) + 1

3 个答案:

答案 0 :(得分:1)

直截了当的方式,8个操作(其他是对常量的操作):

M = (1<<(N-S)) - 1;                     // A mask with S lowest bits.
q = (  ((p & (M<<(2*N+S))) >> (3*S))    // Mask 'i', shift to new position.
     + ((p & (M<<(  N+S))) >> (2*S))    // Likewise for 'j'.
     + ((p & (M<<     S))  >>    S));   // Likewise for 'k'.

看起来很复杂,但实际上并不容易(至少对我而言)让所有常数都正确无误。

要创建具有较少操作的公式,我们观察到将数字向左移位U位与乘以1<<U相同。因此,由于乘法分布,乘以((1<<U1) + (1<<U2) + ...)与向左移动U1U2,...并将所有内容相加。

因此,我们可以尝试屏蔽ijk所需的部分,&#34; shift&#34;通过一次乘法将它们全部放到相对于彼此的正确位置,然后将结果移到右边,到达最终目的地。这为我们提供了三种从q计算p的操作。

不幸的是,有一些限制,特别是对于我们试图同时获得所有三个的情况。当我们将数字相加(间接地,通过将​​几个乘法器加在一起)时,我们必须确保只能在一个数字中设置位,否则我们将得到错误的结果。如果我们尝试一次添加(间接)三个正确移位的数字,我们就有了这个:

iiiii...........jjjjj...........kkkkk.......
 N-S      S      N-S      S      N-S
.....jjjjj...........kkkkk................
 N-S  N-S      S      N-S
..........kkkkk...............
 N-S  N-S  N-S

请注意,第二个和第三个数字中的左侧更远是ij的位,但我们会忽略它们。为此,我们假设乘法在x86上工作:乘以两个类型T给出一些类型T,只有实际结果的最低位(等于结果,如果没有溢出)

因此,为了确保第三个数字中的k位与第一个数字中的j位不重叠,我们需要3*(N-S) <= N,即S >= 2*N/3 N = 8将我们限制为S >= 6(转换后每个组件只有一到两位;不知道您是否曾使用过低精度)。

但是,如果S >= 2*N/3,我们只能使用3个操作:

// Constant multiplier to perform three shifts at once.
F = (1<<(32-3*N)) + (1<<(32-3*N+S)) + (1<<(32-3*N+2*S));
// Mask, shift/combine with multipler, right shift to destination.
q = (((p & ((M<<(2*N+S)) + (M<<(N+S)) + (M<<S))) * F)
     >> (32-3*(N-S)));

如果S的约束太严格(可能是这样),我们可以将第一个和第二个公式结合起来:使用第二种方法计算ik,然后添加来自第一个公式的j。在这里,我们需要在以下数字中不重叠:

iiiii...............kkkkk.......
 N-S   S   N-S   S   N-S
..........kkkkk...............
 N-S  N-S  N-S

即。 3*(N-S) <= 2*NS >= N / 3,或N = 8严格S >= 3。公式如下:

// Constant multiplier to perform two shifts at once.
F = (1<<(32-3*N)) + (1<<(32-3*N+2*S));
// Mask, shift/combine with multipler, right shift to destination
// and then add 'j' from the straightforward formula.
q = ((((p & ((M<<(2*N+S)) + (M<<S))) * F) >> (32-3*(N-S)))
     + ((p & (M<<(N+S))) >> (2*S)));

此公式也适用于S = 4

的示例

这是否比直接方法更快取决于架构。另外,我不知道C ++是否保证了假定的乘法溢出行为。最后,您需要确保值无符号且完全 32位才能使公式生效。

答案 1 :(得分:0)

是否符合您的要求?

#include <cstdint>
#include <iostream>

uint32_t to_q_from_p(uint32_t p, uint32_t N, uint32_t S)
{
   uint32_t mask = ~(~0 << N);
   uint32_t k = p &mask;
   uint32_t j = (p >> N)& mask;
   uint32_t i = (p >> 2*N)&mask;
   return ((i>>S)<<(2*(N-S))) + ((j>>S)<<(N-S)) + (k>>S);;
}

int main()
{
   uint32_t p = 240407;

   uint32_t q = to_q_from_p(p, 8, 4);

   std::cout << q << '\n';

}

如果你假设N总是8并且整数是小端,那么它可以是

uint32_t to_q_from_p(uint32_t p, uint32_t S)
{
   auto ptr = reinterpret_cast<uint8_t*>(&p);
   return ((ptr[2]>>S)<<(2*(8-S))) + ((ptr[1]>>S)<<(8-S)) + (ptr[0]>>S);
}

答案 2 :(得分:0)

如果你不关心兼容性,对于N = 8你可以得到i,j,k那样:

 int p = .... 
 unsigned char *bytes = (char *)&p;

现在kbytes[0]jbytes[1]ibytes[2](我的机器上发现了小端)。但我认为更好的方法是......那样(我们有N_MASK = 2 ^ N - 1)

 int q;
 q = ( p & N_MASK ) >> S;
 p >>= N;
 q |= ( ( p & N_MASK ) >> S ) << S;
 p >>= N;
 q |= ( ( p & N_MASK ) >> S ) << 2*S;