logsumexp算法是一种计算以下表达式的技术:
log( sum( exp( v ) ) )
,
其中 exp
函数逐元素地作用于向量 v
。该算法的典型实现基于下面这个简单的恒等式:
max(v) + log( sum( exp( v - max(v) ) ) )
,
其中max(v)
是v
中最大的元素。该版本数值精度的提升来自于:当 sum( exp( v - max(v) ) )
的值在 1 到 2 之间时,可以使用 log1p
来计算对数。但是,当 v
中的元素数量很多且没有任何一个元素占主导地位时,这一优势就不复存在了。这促使我为 logsumexp
设计了以下递归算法:
logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum( exp( lesserhalf(v) - logsumexp(greaterhalf(v)) ) ) )
,其中
greaterhalf(v)
返回 v
中较大的那一半元素(元素个数为奇数时向上取整),而 lesserhalf(v)
返回 greaterhalf(v)
未返回的其余元素。当 v
中只有一个元素时递归终止,此时 logsumexp(v) = v
。当然,还可以做很多优化,包括:对v
元素进行一次排序,缓存exp(v)
,并从递归中切换到循环。
是否有可能证明:这个递归所对应的算法(它还可以被大幅优化)在某种意义上是数值最优的?具体来说,即它使舍入误差最小,并且只有当任意精度计算的结果在最终转换为有限精度时会下溢/溢出的情况下,它才会发生下溢/溢出。
下面是一个 C 语言的具体实现,为了提高性能,它使用了等价的形式(logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum(exp(lesserhalf(v))) / sum(exp(greaterhalf(v))) )
):
#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
/*
 * MAX(a, b) / MIN(a, b): type-preserving maximum and minimum.
 *
 * Each argument is evaluated exactly once (GCC/Clang statement expression),
 * so arguments with side effects such as `i++` are safe.  The temporaries
 * carry deliberately unusual names: with short names like `x` and `y`, a
 * call such as MAX(x, y) would expand to `double x = (x);` and read the
 * freshly declared, uninitialised temporary instead of the caller's variable.
 *
 * On a tie MAX yields its second argument and MIN yields its first.
 */
#define MAX(a, b) \
    ({ __typeof__ (a) max_lhs_ = (a); \
       __typeof__ (b) max_rhs_ = (b); \
       max_lhs_ > max_rhs_ ? max_lhs_ : max_rhs_; })
#define MIN(a, b) \
    ({ __typeof__ (a) min_lhs_ = (a); \
       __typeof__ (b) min_rhs_ = (b); \
       min_lhs_ <= min_rhs_ ? min_lhs_ : min_rhs_; })
/*
 * qsort() comparator for arrays of double.
 *
 * Produces an ASCENDING order (smallest value first).  NaNs compare greater
 * than every number and equal to each other, so they end up at the end of
 * the sorted array.  The NaN handling is symmetric: qsort() requires a
 * consistent ordering, and returning "greater" for both (NaN, x) and
 * (x, NaN) would violate that requirement.
 *
 * p1, p2: pointers to the two double values being compared.
 * Returns a negative, zero, or positive int when *p1 sorts before, equal to,
 * or after *p2.
 */
int doubleComp( const void *p1, const void *p2 ) {
    const double lhs = *(const double *) p1;
    const double rhs = *(const double *) p2;

    if ( lhs < rhs ) {
        return -1;
    }
    if ( lhs > rhs ) {
        return 1;
    }
    if ( lhs == rhs ) {
        return 0;
    }
    /* Only reachable when at least one operand is NaN. */
    return ( isnan( lhs ) ? 1 : 0 ) - ( isnan( rhs ) ? 1 : 0 );
}
double LogSumExp( double *restrict Inputs, size_t lenInputs ) {
double biggestInput, result;
size_t i;
//Preflight for NANs
for ( i = 0; i < lenInputs; i++ ) {
if ( isnan(Inputs[i]) ) {
return NAN;
}
}
if ( lenInputs == 2 ) {
biggestInput = MAX( Inputs[0], Inputs[1] );
result = biggestInput + log1p( exp( MIN(Inputs[0], Inputs[1]) - biggestInput ) );
}
else if ( lenInputs > 2 ) {
double *restrict sortedInputs, *restrict Sums, *restrict curSum;
size_t lenSums;
double bigpart;
size_t iSum, startpoint, stoppoint;
//Allocate needed memory
lenSums = (unsigned int) ceil( log2((double) lenInputs) ) + 1;
sortedInputs = (double *) malloc( lenInputs * sizeof (double) );
Sums = (double *) malloc( lenSums * sizeof (double) );
if ( sortedInputs == NULL || Sums == NULL ) {
fprintf(stderr, "Memory allocation failed in LogSumExp.\n");
abort();
}
//Sort the inputs, without disturbing the actual Inputs
memcpy( sortedInputs, Inputs, lenInputs * sizeof (double) );
qsort( sortedInputs, lenInputs, sizeof (double),
doubleComp );
//Subtract the biggest input to control possible overflow
biggestInput = sortedInputs[lenInputs - 1];
for ( i = 0; i < lenInputs; ++i ) {
sortedInputs[i] -= biggestInput;
}
//Produce the intermediate Sums
stoppoint = 0;
for ( iSum = 0; iSum < lenSums; iSum++ ) {
curSum = &( Sums[iSum] );
*curSum = 0.0;
startpoint = stoppoint;
stoppoint = lenInputs - startpoint;
stoppoint = ( stoppoint >> 1 );
stoppoint = MAX( stoppoint, 1 );
stoppoint += startpoint;
for ( i = startpoint; i < stoppoint; i++ ) {
*curSum += exp( sortedInputs[i] );
}
}
//Digest the Sums into results
result = 0.0;
for ( iSum = 0; iSum < lenSums - 1; iSum++ ) {
bigpart = 0.0;
for ( i = iSum + 1; i < lenSums; i++ ) {
bigpart += Sums[i];
}
if ( Sums[iSum] > 0.0 ) {
result += log1p( Sums[iSum] / bigpart );
}
}
free( Sums );
free( sortedInputs );
result += biggestInput;
}
else if ( lenInputs == 1 ) {
result = Inputs[0];
}
else {
result = NAN;
}
return result;
}
等价的 Python 实现(依赖 numpy):
from numpy import *
def LogSumExp( inputs ):
    """Return log(sum(exp(inputs))) as a float.

    The values are sorted in ascending order and shifted by the largest one.
    The sorted array is then consumed in chunks, each holding the lesser half
    (rounded down, at least one element) of what remains, and the result is
    accumulated as ``log1p(chunk_sum / sum_of_everything_after_the_chunk)``.

    Parameters
    ----------
    inputs : float or array-like of float
        A single number or any sequence/array of numbers; multi-dimensional
        input is flattened. The argument is never modified.

    Returns
    -------
    float
        * ``nan`` if ``inputs`` is empty or contains a NaN;
        * the largest input if it is +/-infinity (``+inf`` dominates the sum,
          and a maximum of ``-inf`` means every term is ``exp(-inf) == 0``);
        * otherwise the log-sum-exp of the inputs.
    """
    # Local import: the surrounding file only does `from numpy import *`,
    # which would leave `sum`/`any` ambiguous between builtins and numpy.
    import numpy as np

    values = np.atleast_1d( np.asarray( inputs, dtype=float ) ).ravel()
    count = values.size
    if count == 0 or np.isnan( values ).any():
        return float("nan")

    srtInputs = np.sort( values )  # np.sort returns a copy
    bigval = srtInputs[-1]

    # An infinite maximum must not be subtracted: inf - inf would be NaN.
    if count == 1 or np.isinf( bigval ):
        return float( bigval )

    if count == 2:
        return float( bigval + np.log1p( np.exp( srtInputs[0] - bigval ) ) )

    # Shifting by the (finite) maximum keeps every exp() at or below 1.
    expInputs = np.exp( srtInputs - bigval )

    result = 0.0
    startpoint = 0
    endpoint = count // 2
    while startpoint < count - 1:
        smallpart = expInputs[startpoint:endpoint].sum()
        # A chunk that underflowed to zero contributes log1p(0) == 0.
        if smallpart > 0.0:
            result += np.log1p( smallpart / expInputs[endpoint:].sum() )
        startpoint = endpoint
        endpoint = startpoint + max( ( count - startpoint ) // 2, 1 )

    return float( result + bigval )