Logsumexp的这种实现在数学上是最优的吗?

时间:2016-09-15 01:01:56

标签: algorithm precision numerical-methods

logsumexp算法是一种计算以下表达式的技术:

log( sum( exp( v ) ) )

其中 exp 函数逐元素作用于向量 v。该算法的典型实现基于如下直观的恒等式:

max(v) + log( sum( exp( v - max(v) ) ) )

其中 max(v) 是 v 中最大的元素。这个版本在数值精度上的优势在于:当 sum( exp( v - max(v) ) ) 介于 1 和 2 之间时,可以借助 log1p 来计算对数。然而当 v 中元素数量很多、且没有任何一个元素占主导地位时,这一优势便不复存在。这促使我为 logsumexp 设计了如下递归算法:

logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum( exp( lesserhalf(v) - logsumexp(greaterhalf(v)) ) ) )

greaterhalf(v) 返回 v 中较大的那一半元素(元素个数向上取整),lesserhalf(v) 则返回 greaterhalf(v) 未返回的其余元素。当 v 中只剩一个元素时递归终止,此时 logsumexp(v) = v。当然,还可以做很多优化,包括:只对 v 的元素排序一次、缓存 exp(v),以及把递归改写为循环。

能否证明:这种基于递归的算法(在做了上述各种优化之后)在某种意义上是数值最优的?具体来说,它能否使舍入误差最小,并且只在如下情况才发生下溢/上溢:以任意精度计算出的结果在最终转换为有限精度时本身就会下溢/上溢。

下面是一个具体的 C 语言实现。为了提升性能,它使用了上述递归的一个数学上等价的变形(logsumexp(v) = logsumexp(greaterhalf(v)) + log1p( sum(exp(lesserhalf(v))) / sum(exp(greaterhalf(v))) )):

#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * MAX / MIN: the larger / smaller of two values, evaluating each argument
 * exactly once. Uses the GNU C extensions "({ ... })" (statement
 * expressions) and __typeof__, so this needs GCC or Clang.
 *
 * The temporaries use distinctive names so that a caller's own variables
 * are not captured by the macro's declarations: with plain names such as
 * x / y, MAX(x, y) would initialize the hidden x from itself.
 * MIN returns its first argument when the two compare equal.
 */
#define MAX(a, b) \
    ({ __typeof__ (a) max_a_ = (a); \
        __typeof__ (b) max_b_ = (b); \
        max_a_ > max_b_ ? max_a_ : max_b_; })
#define MIN(a, b) \
    ({ __typeof__ (a) min_a_ = (a); \
        __typeof__ (b) min_b_ = (b); \
        min_a_ <= min_b_ ? min_a_ : min_b_; })

// qsort comparator for doubles that produces ASCENDING order (smallest
// first), so after sorting the largest value is the last element.
// NaNs compare equal to each other and greater than every non-NaN value,
// so they collect at the end of the array. Treating NaN symmetrically
// keeps the ordering consistent, as qsort requires, even when NaNs appear
// as either operand.
int doubleComp( const void *p1, const void *p2 ) {
    const double d1 = *(const double *)p1;
    const double d2 = *(const double *)p2;
    // isnan only promises a nonzero result, so normalize it to 0/1.
    const int d1IsNan = isnan( d1 ) != 0;
    const int d2IsNan = isnan( d2 ) != 0;

    if ( d1IsNan || d2IsNan ) {
        return d1IsNan - d2IsNan;   // NaN > number, NaN == NaN
    }
    return ( d1 > d2 ) - ( d1 < d2 );
}
/*
 * LogSumExp: compute log( sum_i exp( Inputs[i] ) ) over lenInputs values.
 *
 * Instead of the usual max + log( sum( exp( v - max ) ) ), a copy of the
 * inputs is sorted ascending and split repeatedly. Of the r values not yet
 * consumed, the smallest floor(r/2) (at least one) form the next group.
 * Each group k contributes S_k = sum( exp( v - max ) ) over its members.
 * The last group is the single largest input, so S_{g-1} = 1. The result
 * is assembled as
 *
 *     max + sum_{k < g-1} log1p( S_k / (S_{k+1} + ... + S_{g-1}) )
 *
 * so every logarithm is taken through log1p of a ratio.
 *
 * Inputs     values to combine; not modified (a private copy is sorted).
 * lenInputs  number of values.
 *
 * Returns:
 *   - NaN if any input is NaN, or if lenInputs == 0. The empty sum would
 *     mathematically give -INFINITY; NaN is kept for existing callers.
 *   - +INFINITY if any input is +INFINITY.
 *   - -INFINITY if every input is -INFINITY.
 * On allocation failure, prints a message to stderr and calls abort().
 *
 * Relies on the file's doubleComp (ascending order) and MAX / MIN macros.
 */
double LogSumExp( double *restrict Inputs, size_t lenInputs ) {
    // A single NaN anywhere makes the whole result NaN.
    for ( size_t i = 0; i < lenInputs; i++ ) {
        if ( isnan( Inputs[i] ) ) {
            return NAN;
        }
    }

    if ( lenInputs == 0 ) {
        return NAN;
    }
    if ( lenInputs == 1 ) {
        return Inputs[0];
    }

    if ( lenInputs == 2 ) {
        double biggestInput = MAX( Inputs[0], Inputs[1] );
        // Without this, inf - inf below would yield NaN. A +inf maximum
        // dominates the sum; a -inf maximum means both terms are exp(-inf) = 0.
        if ( isinf( biggestInput ) ) {
            return biggestInput;
        }
        return biggestInput + log1p( exp( MIN( Inputs[0], Inputs[1] ) - biggestInput ) );
    }

    // Sort a private copy so the caller's array is left untouched.
    double *sortedInputs = malloc( lenInputs * sizeof *sortedInputs );
    if ( sortedInputs == NULL ) {
        fprintf(stderr, "Memory allocation failed in LogSumExp.\n");
        abort();
    }
    memcpy( sortedInputs, Inputs, lenInputs * sizeof *sortedInputs );
    qsort( sortedInputs, lenInputs, sizeof *sortedInputs, doubleComp );

    // Ascending sort: the maximum is the last element.
    double biggestInput = sortedInputs[lenInputs - 1];
    if ( isinf( biggestInput ) ) {
        // Same reasoning as the two-element case: avoid inf - inf = NaN.
        free( sortedInputs );
        return biggestInput;
    }

    // The number of groups is ceil(log2(lenInputs)) + 1. That never exceeds
    // the bit width of size_t plus one, so a fixed-size array always
    // suffices. The groups are counted with integer arithmetic rather than a
    // floating-point log2/ceil estimate, which can come out too small for
    // very large lenInputs and silently drop inputs.
    double Sums[sizeof (size_t) * CHAR_BIT + 1];
    size_t lenSums = 0;
    for ( size_t startpoint = 0; startpoint < lenInputs; ) {
        size_t groupLen = ( lenInputs - startpoint ) / 2;
        if ( groupLen == 0 ) {
            groupLen = 1;   // the final group: the single largest input
        }

        double groupSum = 0.0;
        for ( size_t i = startpoint; i < startpoint + groupLen; i++ ) {
            // Subtracting the maximum keeps every term in [0, 1].
            groupSum += exp( sortedInputs[i] - biggestInput );
        }
        Sums[lenSums++] = groupSum;
        startpoint += groupLen;
    }
    free( sortedInputs );

    // Walk from the top group down, carrying the running sum of all groups
    // above the current one (the "bigpart" each log1p ratio divides by).
    double result = 0.0;
    double bigpart = Sums[lenSums - 1];
    for ( size_t iSum = lenSums - 1; iSum-- > 0; ) {
        // A group whose terms all underflowed contributes log1p(0) = 0.
        if ( Sums[iSum] > 0.0 ) {
            result += log1p( Sums[iSum] / bigpart );
        }
        bigpart += Sums[iSum];
    }

    return result + biggestInput;
}

同一算法的 Python 实现(依赖 numpy):

from numpy import *

def LogSumExp( inputs ):
    """Return log(sum(exp(inputs))) using the sorted half-splitting scheme.

    The values are sorted ascending. Of the r values not yet consumed, the
    smallest r // 2 (at least one) form the next group. Each group's share
    is added via log1p(group_sum / sum_of_all_larger_values), working on
    exp(v - max(v)); the largest value alone remains and contributes log(1).

    inputs: a scalar (Python or numpy) or an array-like of reals. Arrays of
        any shape are flattened, so all elements are combined.

    Returns:
        - NaN if any element is NaN, or if inputs is empty.
        - A scalar input unchanged.
        - For a single element, that element.
        - +inf if any element is +inf; -inf if every element is -inf.
    """
    values = asarray( inputs )
    if isnan( values ).any():
        return float("nan")

    if values.ndim == 0:
        # A lone scalar (Python float or numpy scalar) is its own log-sum-exp.
        return inputs

    values = values.ravel()
    count = len( values )
    if count == 0:
        return float("nan")
    if count == 1:
        return values[0]

    srtInputs = sort( values )
    bigval = srtInputs[-1]
    if isinf( bigval ):
        # Avoid inf - inf = NaN. A +inf maximum dominates the sum; a -inf
        # maximum means every term is exp(-inf) = 0.
        return bigval
    # Every term is now in [0, 1].
    expInputs = exp( srtInputs - bigval )

    result = 0.0
    startpoint = 0
    while startpoint < count - 1:
        # Explicit arithmetic instead of max(): under `from numpy import *`,
        # `max` may be numpy's, which treats a second argument as an axis.
        half = ( count - startpoint ) // 2
        endpoint = startpoint + ( half if half > 0 else 1 )

        smallpart = expInputs[startpoint:endpoint].sum()
        if smallpart > 0.0:
            result += log1p( smallpart / expInputs[endpoint:].sum() )
        startpoint = endpoint

    return result + bigval

0 个答案:

没有答案