我有一个随机过程,当被调用时,返回一个0到K-1之间的随机数,其中K可能相当高。我想跟踪任何结果发生的次数,并将所有计数归一化为概率分布。我希望每次调用随机过程时都这样做,以便我对随机过程的分布估计尽可能最新。
一种天真的方法可能如下:
while ( true ) {
int n = randomProcess();
++totalCount;
++count[n];
update();
do_work_with_updated_prob_vector();
}
void update() {
for ( int i = 0; i < K; ++i )
prob[i] = count[i] / static_cast<double>(totalCount);
}
然而,当K开始变大时,这种方法需要在每次概率更新时读取整个计数向量,这是由于高速缓存未命中和存储器访问成本而不希望的。我设计了另一个解决方案,在我的有限测试中,K~1000的速度提高了约30%。新的更新函数需要知道最后更新的元素的索引:
void fastUpdate(int id) {
if ( totalCount == 1 ) {
prob[id] = 1.0;
return;
}
double newProb = count[id] / static_cast<double>(totalCount - 1);
double newProbSum = 1.0 + ( newProb - prob[id] );
prob[id] = newProb;
for ( int i = 0; i < K; ++i )
prob[i] /= newProbSum
}
这种方法在理论上有效,但是我担心浮点精度误差会由于执行的不完美标准化而累积。我是否应该偶尔调用基本update
函数来摆脱它们?如果是这样,多久一次?这个错误有多大?我对这类问题几乎没有经验,我知道我不需要低估它们。
我正在编写一系列AI算法,需要学习最初未知的环境。在这种情况下,通过近似分布中看到的内容来学习环境。在每次迭代时,算法将根据新数据(不仅包括更新的prob
向量,还包括其他内容)修改其决策。由于这些值不仅可以使用,而且也可以在一次迭代中多次使用,我猜想最好先计算一次然后再使用它,这就是我对更新函数所做的事情。
此外,我想补充一点,我是否需要在每次迭代时更新prob
向量,这在这里确实不是问题。函数fastUpdate
的契约是它将进行快速更新,这就是我的问题所源自的地方。如果我不需要经常更新,我将通过在每次迭代时不调用该函数来实现。因为此刻我需要打电话给我,我正在这样做。我希望这可以澄清。
答案 0 :(得分:2)
就像一个例如,拿这个python例子:
for i in range(1000000):
x = rnd.randrange(0,10)
intar.append(x)
dblar.append(x/100.0)
intsum = 0
for i in intar:
intsum += i
dblsum = 0.0
for d in dblar:
dblsum += d
print("int: %f, dbl: %f, diff: %f" % ((intsum/100.0), dblsum, ((intsum/100.0)-dblsum)))
的产率:
int: 45012.230000, dbl: 45012.200000, diff: 0.030000
现在,我强制使用除数来确保存在一致的舍入误差。我猜测输入数据分布的性质对于确定会累积多少错误至关重要;虽然我从来没有新的或忘记了得出答案所必需的数学。由于基于编译器选项已知浮点数学的确切行为,因此应该可以在给出输入数据的概念的情况下导出错误范围。
答案 1 :(得分:2)
在添加项目时更新prob
,而不是在需要读取概率时更新它。在读取之前,请使用布尔标志来指示prob
是否需要更新。
while ( true ) {
int n = randomProcess();
++totalCount;
++count[n];
dirty = true;
}
void updateBeforeRead() {
if(dirty) {
for ( int i = 0; i < K; ++i )
prob[i] = count[i] / static_cast<double>(totalCount);
}
dirty = false;
}
}
如果您的使用情况在大量样本之间翻转,然后根据概率进行大量计算,那么这应该是有效的,同时限制舍入问题。
答案 2 :(得分:0)
行。第二次尝试......
鉴于您的测试表明更新prob
会直接提高算法的性能,您可以通过定期重置来最小化舍入错误。
void fastUpdate(int id) {
if ( totalCount == 1 ) {
prob[id] = 1.0;
fastUpdates = 0;
return;
}
if(fastUpdates<maxFastUpdates) {
double newProb = count[id] / static_cast<double>(totalCount - 1);
double newProbSum = 1.0 + ( newProb - prob[id] );
prob[id] = newProb;
for ( int i = 0; i < K; ++i )
prob[i] /= newProbSum;
++fastUpdates;
}
else {
update();
fastUpdates = 0;
}
}
答案 3 :(得分:0)
最快的操作是您不执行的操作。即使您最终使用每次更新后生成的每个值,如果您不必将更新的值写回内存,也可以保存。你也可以通过乘法而不是除法来削减几个时钟。乘法很可能比你保存的内存访问速度快。
template<int K>
class Prob
{
private:
int count[K];
int totalCount;
double multiplier;
public:
update(int id)
{
++count[id];
++totalCount;
multiplier = 1.0 / totalCount;
}
double operator[](int id)
{
return count[id] * multiplier;
}
};
Prob<K> prob;
while ( true ) {
int n = randomProcess();
prob.update(n);
// demo
double sum = 0.0;
for (int i = 0; i < K; i++)
sum += prob[i];
do_work_with_updated_prob_vector(prob);
}