当我的发射序列太小时,我的hmm中的状态转移概率和发射概率会收敛到纳米

时间:2014-02-23 21:54:58

标签: java machine-learning hidden-markov-models

我正在尝试使用缩放的Baum-Welch训练一个双态隐马尔可夫模型,但是当我的发射序列太小时我注意到了。我的概率转向java中的NaN。这是正常的吗?我在下面的java中发布了我的代码:

import java.util.ArrayList;
/*
Scaled Baum-Welch Algorithm implementation
author: Ricky Chang
*/

public class HMModeltest {

public static double[][] stateTransitionMatrix = new double[2][2]; // State Transition Matrix
public static double[][] emissionMatrix; // Emission Probability Matrix
public static double[] pi = new double[2]; // Initial State Distribution

double[] scaler; // This is used for scaling to prevent underflow
private static int emissions_id = 1; // To identify if the emissions are for the price changes or spread changes
private static int numEmissions = 0; // The amount of emissions
private static int numStates = 2; // The number of states in hmm
public static double improvementVar; // Used to assess how much the model has improved
private static double genState; // Generated state, it is used to generate observations below

// Create an ArrayList to store the emissions
public static ArrayList<Integer> eSequence = new ArrayList<Integer>();


// Initialize H, emission_id: 1 is price change, 2 are spreads; count is for the amount of different emissions
public HMModeltest(int id, int count){
    emissions_id = id;
    numEmissions = count;

    stateTransitionMatrix = set2DValues(numStates,numStates); // Give the STM  row stochastic values
    emissionMatrix = new double[numStates][numEmissions];
    emissionMatrix = set2DValues(numStates,numEmissions); // Give the Emission probability matrix row stochastic values
    pi = set1DValues(numStates); // Give the initial matrix row stochastic values 
}

// Categorize the price change emissions; I may want to put these in the Implementation.
private int identifyE1(double e){

    if( e == 0) return 4;
    if( e > 0){
        if(e == 1) return 5;
        else if(e == 3) return 6;
        else if(e == 5) return 7;
        else return 8;
    }
    else{
        if(e == -1) return 3;
        else if(e == -3) return 2;
        else if(e == -5) return 1;
        else return 0;
    }
}

// Categorize the spread emissions
private int identifyE2(double e){

    if(e == 1) return 0;
    else if(e == 3) return 1;
    else return 2;
}

public void updateE(int emission){
    if(emissions_id == 1) eSequence.add( identifyE1(emission) );
    else eSequence.add( identifyE2(emission) );
}

// Used to intialize random row stochastic values to vectors
private double[] set1DValues(int col){
    double sum = 0;
    double temp = 0;
    double [] returnVector = new double[col];

    for(int i = 0; i < col; i++){
        temp = Math.round(Math.random() * 1000);
        returnVector[i] = temp;
        sum = sum + temp;
    }
    for(int i = 0; i < col; i++){
        returnVector[i] = returnVector[i] / sum;
    }

    return returnVector;
}

// Used to initialize random row stochastic values to matrices
public double[][] set2DValues(int row, int col){
    double sum = 0;
    double temp = 0;
    double[][] returnMatrix = new double[row][col];

    for(int i = 0; i < row; i++){
        for(int j = 0; j < col; j++){
            temp = Math.round(Math.random() * 1000);
            returnMatrix[i][j] = temp;
            sum = sum + temp;
        }
        for(int j = 0; j < col; j++){
            returnMatrix[i][j] = returnMatrix[i][j] / sum;
        }

        sum = 0;
    }

    return returnMatrix;
}

// Use forward algorithm to calculate alpha for all states and times
public double[][] forwardAlgo(int time){
    double alpha[][] = new double[numStates][time];
    scaler = new double[time];

    // Intialize alpha for time 0
    scaler[0] = 0; // c0 is for scaling purposes to avoid underflow
    for(int i = 0; i < numStates; i ++){
        alpha[i][0] = pi[i] * emissionMatrix[i][eSequence.get(0)];
        scaler[0] = scaler[0] + alpha[i][0];
    }

    // Scale alpha_0
    scaler[0] = 1 / scaler[0];
    for(int i = 0; i < numStates; i++){
        alpha[i][0] = scaler[0] * alpha[i][0];
    }

    // Use recursive method to calculate alpha
    double tempAlpha = 0;
    for(int t = 1; t < time; t++){
        scaler[t] = 0;
        for(int i = 0; i < numStates; i++){
            for(int j = 0; j < numStates; j++){
                tempAlpha = tempAlpha + alpha[j][t-1] * stateTransitionMatrix[j][i];
            }
            alpha[i][t] = tempAlpha * emissionMatrix[i][eSequence.get(t)];
            scaler[t] = scaler[t] + alpha[i][t];
            tempAlpha = 0;
        }

        scaler[t] = 1 / scaler[t];
        for(int i = 0; i < numStates; i++){
            alpha[i][t] = scaler[t] * alpha[i][t];
        }
    }

    System.out.format("scaler: ");
    for(int t = 0; t < time; t++){
        System.out.format("%f, ", scaler[t]);
    }
    System.out.print('\n');
    return alpha;
}

// Use backward algorithm to calculate beta for all states
public double[][] backwardAlgo(int time){
    double beta[][] = new double[2][time];

    // Intialize beta for current time
    for(int i = 0; i < numStates; i++){
        beta[i][time-1] = scaler[time-1];
    }

    // Use recursive method to calculate beta
    double tempBeta = 0;
    for(int t = time-2; t >= 0; t--){
        for(int i = 0; i < numStates; i++){
            for(int j = 0; j < numStates; j++){
                tempBeta = tempBeta + (stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]);
            }
            beta[i][t] = tempBeta;
            beta[i][t] = scaler[t] * beta[i][t];
            tempBeta = 0;
        }
    }

    return beta;
}

// Calculate the probability of emission sequence given the model (it is also the denominator to calculate gamma and digamma)
public double calcP(int t, double[][] alpha, double[][] beta){

    double p = 0;

    for(int i = 0; i < numStates; i++){
        for(int j = 0; j < numStates; j++){
            p = p + (alpha[i][t] * stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]);
        }
    }
    return p;
}

// Calculate digamma; i and j are both states
public double calcDigamma(double p, int t, int i, int j, double[][] alpha, double[][] beta){
    double digamma = (alpha[i][t] * stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]) / p;
    return digamma;
}

public void updatePi(double[][] gamma){
    for(int i = 0; i < numStates; i++){
        pi[i] = gamma[i][0];
    }
}

public void updateAll(){
    int time = eSequence.size();
    double alpha[][] = forwardAlgo(time);
    double beta[][] = backwardAlgo(time);
    double initialp = calcLogEProb(time);
    double nextState0, nextState1;

    double p = 0;
    double[][][] digamma = new double[numStates][numStates][time];
    double[][] gamma = new double[numStates][time];

    for(int t = 0; t < time-1; t++){
        p = calcP(t, alpha, beta);
        for(int i = 0; i < numStates; i++){
            gamma[i][t] = 0;
            for(int j = 0; j < numStates; j++){
                digamma[i][j][t] = calcDigamma(p, t, i, j, alpha, beta);
                gamma[i][t] = gamma[i][t] + digamma[i][j][t];
            }
        }
    }

    updatePi(gamma);
    updateA(digamma, gamma);
    updateB(gamma);

    alpha = forwardAlgo(time);
    double postp = calcLogEProb(time);
    improvementVar = postp - initialp;
}

// Update the state transition matrix
public void updateA(double[][][] digamma, double[][] gamma){
    int time = eSequence.size();
    double num = 0;
    double denom = 0;

    for(int i = 0; i < numStates; i++){
        for(int j = 0; j < numStates; j++){
            for(int t = 0; t < time-1; t++){
                num = num + digamma[i][j][t];
                denom = denom + gamma[i][t];
            }
            stateTransitionMatrix[i][j] = num/denom;
            num = 0;
            denom = 0;
        }
    }
}

public void updateB(double[][] gamma){
    int time = eSequence.size();
    double num = 0;
    double denom = 0;

    // k is an observation, j is a state, t is time
    for(int i = 0; i < numStates; i++){
        for(int k = 0; k < numEmissions; k++){
            for(int t = 0; t < time-1; t++){
                if( eSequence.get(t) == k) num = num + gamma[i][t];
                denom = denom + gamma[i][t];
            }
            emissionMatrix[i][k] = num/denom;
            num = 0;
            denom = 0;
        }
    }
}

public double calcLogEProb(int time){
    double logProb = 0;

    for(int t = 0; t < time; t++){
        logProb = logProb + Math.log(scaler[t]);
    }

    return -logProb;
}

public double calcNextState(int time, int state, double[][] gamma){
    double p = 0;
    for(int i = 0; i < numStates; i++){
        for(int j = 0; j < numStates; j++){
            p = p + gamma[i][time-2] * stateTransitionMatrix[i][j] * stateTransitionMatrix[j][state];
        }
    }

    return p;
}

// Print parameters
public void print(){
    System.out.println("Pi:");
    System.out.print('[');
    for(int i = 0; i < 2; i++){
        System.out.format("%f, ", pi[i]);
    }
    System.out.print(']');
    System.out.print('\n');

    System.out.println("A:");
    for(int i = 0; i < 2; i++){
        System.out.print('[');
        for(int j = 0; j < 2; j++){
            System.out.format("%f, ", stateTransitionMatrix[i][j]);
        }
        System.out.print(']');
        System.out.print('\n');
    }

    System.out.println("B:");
    for(int i = 0; i < 2; i++){
        System.out.print('[');
        for(int j = 0; j < 9; j++){
            System.out.format("%f, ", emissionMatrix[i][j]);
        }
        System.out.print(']');
        System.out.print('\n');
    }
    System.out.print('\n');
}

/* Generate sample data to test HMM training with the following params:
 * [ .3, .7 ]
 * [ .8, .2 ]                       [ .45 .1  .08 .05 .03 .02 .05 .2 .02 ]
 *                                  [ .36 .02 .06 .15 .04 .05  .2 .1 .02 ]
 * With these as observations:        {-10, -5, -3, -1, 0, 1, 3, 5, 10}
 */
public static int sampleDataGen(){
    double rand = 0;

    rand = Math.random();
    if(genState == 1){
        if(rand < .3) genState = 1;
        else genState = 2;

        rand = Math.random();
        if(rand < .45) return -10;
        else if(rand < .55) return -5;
        else if(rand < .63) return -3;
        else if(rand < .68) return -1;
        else if(rand < .71) return 0;
        else if(rand < .73) return 1;
        else if(rand < .78) return 3;
        else if(rand < .98) return 5;
        else return 10;
    }
    else {
        if(rand < .8) genState = 1;
        else genState = 2;

        rand = Math.random();
        if(rand < .36) return -10;
        else if(rand < .38) return -5;
        else if(rand < .44) return -3;
        else if(rand < .59) return -1;
        else if(rand < .63) return 0;
        else if(rand < .68) return 1;
        else if(rand < .88) return 3;
        else if(rand < .98) return 5;
        else return 10;
    }
}


public static void main(String[] args){
    HMModeltest test = new HMModeltest(1,9);
    test.print();

    System.out.print('\n');
    for(int i = 0; i < 20; i++){
        test.updateE(sampleDataGen());
    }

    test.updateAll();
    System.out.print('\n');
    test.print();
    System.out.print('\n');


    for(int i = 0; i < 10; i++){
        test.updateE(sampleDataGen());
    }
    test.updateAll();
    System.out.print('\n');
    test.print();
    System.out.print('\n');
}

}

我的猜测是,由于样本太小,有时某些观察结果不存在概率。但是有一些确认会很好。

1 个答案:

答案 0 :(得分:1)

你可以参考&#34; Scaling&#34; Rabiner's paper中的部分,解决了下溢问题。

您也可以在日志空间中进行计算,这是HTK和R的作用。乘法和除法成为加法和减法。对于其他两个,请查看相应工具包中的LAdd / LSublogspace_add / logspace_sub函数。

log-sum-exp技巧也可能有所帮助。