我正在尝试使用缩放的Baum-Welch训练一个双态隐马尔可夫模型,但是当我的发射序列太小时我注意到了。我的概率转向java中的NaN。这是正常的吗?我在下面的java中发布了我的代码:
import java.util.ArrayList;
/*
Scaled Baum-Welch Algorithm implementation
author: Ricky Chang
*/
public class HMModeltest {
public static double[][] stateTransitionMatrix = new double[2][2]; // State Transition Matrix
public static double[][] emissionMatrix; // Emission Probability Matrix
public static double[] pi = new double[2]; // Initial State Distribution
double[] scaler; // This is used for scaling to prevent underflow
private static int emissions_id = 1; // To identify if the emissions are for the price changes or spread changes
private static int numEmissions = 0; // The amount of emissions
private static int numStates = 2; // The number of states in hmm
public static double improvementVar; // Used to assess how much the model has improved
private static double genState; // Generated state, it is used to generate observations below
// Create an ArrayList to store the emissions
public static ArrayList<Integer> eSequence = new ArrayList<Integer>();
// Initialize H, emission_id: 1 is price change, 2 are spreads; count is for the amount of different emissions
public HMModeltest(int id, int count){
emissions_id = id;
numEmissions = count;
stateTransitionMatrix = set2DValues(numStates,numStates); // Give the STM row stochastic values
emissionMatrix = new double[numStates][numEmissions];
emissionMatrix = set2DValues(numStates,numEmissions); // Give the Emission probability matrix row stochastic values
pi = set1DValues(numStates); // Give the initial matrix row stochastic values
}
// Categorize the price change emissions; I may want to put these in the Implementation.
private int identifyE1(double e){
if( e == 0) return 4;
if( e > 0){
if(e == 1) return 5;
else if(e == 3) return 6;
else if(e == 5) return 7;
else return 8;
}
else{
if(e == -1) return 3;
else if(e == -3) return 2;
else if(e == -5) return 1;
else return 0;
}
}
// Categorize the spread emissions
private int identifyE2(double e){
if(e == 1) return 0;
else if(e == 3) return 1;
else return 2;
}
public void updateE(int emission){
if(emissions_id == 1) eSequence.add( identifyE1(emission) );
else eSequence.add( identifyE2(emission) );
}
// Used to intialize random row stochastic values to vectors
private double[] set1DValues(int col){
double sum = 0;
double temp = 0;
double [] returnVector = new double[col];
for(int i = 0; i < col; i++){
temp = Math.round(Math.random() * 1000);
returnVector[i] = temp;
sum = sum + temp;
}
for(int i = 0; i < col; i++){
returnVector[i] = returnVector[i] / sum;
}
return returnVector;
}
// Used to initialize random row stochastic values to matrices
public double[][] set2DValues(int row, int col){
double sum = 0;
double temp = 0;
double[][] returnMatrix = new double[row][col];
for(int i = 0; i < row; i++){
for(int j = 0; j < col; j++){
temp = Math.round(Math.random() * 1000);
returnMatrix[i][j] = temp;
sum = sum + temp;
}
for(int j = 0; j < col; j++){
returnMatrix[i][j] = returnMatrix[i][j] / sum;
}
sum = 0;
}
return returnMatrix;
}
// Use forward algorithm to calculate alpha for all states and times
public double[][] forwardAlgo(int time){
double alpha[][] = new double[numStates][time];
scaler = new double[time];
// Intialize alpha for time 0
scaler[0] = 0; // c0 is for scaling purposes to avoid underflow
for(int i = 0; i < numStates; i ++){
alpha[i][0] = pi[i] * emissionMatrix[i][eSequence.get(0)];
scaler[0] = scaler[0] + alpha[i][0];
}
// Scale alpha_0
scaler[0] = 1 / scaler[0];
for(int i = 0; i < numStates; i++){
alpha[i][0] = scaler[0] * alpha[i][0];
}
// Use recursive method to calculate alpha
double tempAlpha = 0;
for(int t = 1; t < time; t++){
scaler[t] = 0;
for(int i = 0; i < numStates; i++){
for(int j = 0; j < numStates; j++){
tempAlpha = tempAlpha + alpha[j][t-1] * stateTransitionMatrix[j][i];
}
alpha[i][t] = tempAlpha * emissionMatrix[i][eSequence.get(t)];
scaler[t] = scaler[t] + alpha[i][t];
tempAlpha = 0;
}
scaler[t] = 1 / scaler[t];
for(int i = 0; i < numStates; i++){
alpha[i][t] = scaler[t] * alpha[i][t];
}
}
System.out.format("scaler: ");
for(int t = 0; t < time; t++){
System.out.format("%f, ", scaler[t]);
}
System.out.print('\n');
return alpha;
}
// Use backward algorithm to calculate beta for all states
public double[][] backwardAlgo(int time){
double beta[][] = new double[2][time];
// Intialize beta for current time
for(int i = 0; i < numStates; i++){
beta[i][time-1] = scaler[time-1];
}
// Use recursive method to calculate beta
double tempBeta = 0;
for(int t = time-2; t >= 0; t--){
for(int i = 0; i < numStates; i++){
for(int j = 0; j < numStates; j++){
tempBeta = tempBeta + (stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]);
}
beta[i][t] = tempBeta;
beta[i][t] = scaler[t] * beta[i][t];
tempBeta = 0;
}
}
return beta;
}
// Calculate the probability of emission sequence given the model (it is also the denominator to calculate gamma and digamma)
public double calcP(int t, double[][] alpha, double[][] beta){
double p = 0;
for(int i = 0; i < numStates; i++){
for(int j = 0; j < numStates; j++){
p = p + (alpha[i][t] * stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]);
}
}
return p;
}
// Calculate digamma; i and j are both states
public double calcDigamma(double p, int t, int i, int j, double[][] alpha, double[][] beta){
double digamma = (alpha[i][t] * stateTransitionMatrix[i][j] * emissionMatrix[j][eSequence.get(t+1)] * beta[j][t+1]) / p;
return digamma;
}
public void updatePi(double[][] gamma){
for(int i = 0; i < numStates; i++){
pi[i] = gamma[i][0];
}
}
public void updateAll(){
int time = eSequence.size();
double alpha[][] = forwardAlgo(time);
double beta[][] = backwardAlgo(time);
double initialp = calcLogEProb(time);
double nextState0, nextState1;
double p = 0;
double[][][] digamma = new double[numStates][numStates][time];
double[][] gamma = new double[numStates][time];
for(int t = 0; t < time-1; t++){
p = calcP(t, alpha, beta);
for(int i = 0; i < numStates; i++){
gamma[i][t] = 0;
for(int j = 0; j < numStates; j++){
digamma[i][j][t] = calcDigamma(p, t, i, j, alpha, beta);
gamma[i][t] = gamma[i][t] + digamma[i][j][t];
}
}
}
updatePi(gamma);
updateA(digamma, gamma);
updateB(gamma);
alpha = forwardAlgo(time);
double postp = calcLogEProb(time);
improvementVar = postp - initialp;
}
// Update the state transition matrix
public void updateA(double[][][] digamma, double[][] gamma){
int time = eSequence.size();
double num = 0;
double denom = 0;
for(int i = 0; i < numStates; i++){
for(int j = 0; j < numStates; j++){
for(int t = 0; t < time-1; t++){
num = num + digamma[i][j][t];
denom = denom + gamma[i][t];
}
stateTransitionMatrix[i][j] = num/denom;
num = 0;
denom = 0;
}
}
}
public void updateB(double[][] gamma){
int time = eSequence.size();
double num = 0;
double denom = 0;
// k is an observation, j is a state, t is time
for(int i = 0; i < numStates; i++){
for(int k = 0; k < numEmissions; k++){
for(int t = 0; t < time-1; t++){
if( eSequence.get(t) == k) num = num + gamma[i][t];
denom = denom + gamma[i][t];
}
emissionMatrix[i][k] = num/denom;
num = 0;
denom = 0;
}
}
}
public double calcLogEProb(int time){
double logProb = 0;
for(int t = 0; t < time; t++){
logProb = logProb + Math.log(scaler[t]);
}
return -logProb;
}
public double calcNextState(int time, int state, double[][] gamma){
double p = 0;
for(int i = 0; i < numStates; i++){
for(int j = 0; j < numStates; j++){
p = p + gamma[i][time-2] * stateTransitionMatrix[i][j] * stateTransitionMatrix[j][state];
}
}
return p;
}
// Print parameters
public void print(){
System.out.println("Pi:");
System.out.print('[');
for(int i = 0; i < 2; i++){
System.out.format("%f, ", pi[i]);
}
System.out.print(']');
System.out.print('\n');
System.out.println("A:");
for(int i = 0; i < 2; i++){
System.out.print('[');
for(int j = 0; j < 2; j++){
System.out.format("%f, ", stateTransitionMatrix[i][j]);
}
System.out.print(']');
System.out.print('\n');
}
System.out.println("B:");
for(int i = 0; i < 2; i++){
System.out.print('[');
for(int j = 0; j < 9; j++){
System.out.format("%f, ", emissionMatrix[i][j]);
}
System.out.print(']');
System.out.print('\n');
}
System.out.print('\n');
}
/* Generate sample data to test HMM training with the following params:
* [ .3, .7 ]
* [ .8, .2 ] [ .45 .1 .08 .05 .03 .02 .05 .2 .02 ]
* [ .36 .02 .06 .15 .04 .05 .2 .1 .02 ]
* With these as observations: {-10, -5, -3, -1, 0, 1, 3, 5, 10}
*/
public static int sampleDataGen(){
double rand = 0;
rand = Math.random();
if(genState == 1){
if(rand < .3) genState = 1;
else genState = 2;
rand = Math.random();
if(rand < .45) return -10;
else if(rand < .55) return -5;
else if(rand < .63) return -3;
else if(rand < .68) return -1;
else if(rand < .71) return 0;
else if(rand < .73) return 1;
else if(rand < .78) return 3;
else if(rand < .98) return 5;
else return 10;
}
else {
if(rand < .8) genState = 1;
else genState = 2;
rand = Math.random();
if(rand < .36) return -10;
else if(rand < .38) return -5;
else if(rand < .44) return -3;
else if(rand < .59) return -1;
else if(rand < .63) return 0;
else if(rand < .68) return 1;
else if(rand < .88) return 3;
else if(rand < .98) return 5;
else return 10;
}
}
public static void main(String[] args){
HMModeltest test = new HMModeltest(1,9);
test.print();
System.out.print('\n');
for(int i = 0; i < 20; i++){
test.updateE(sampleDataGen());
}
test.updateAll();
System.out.print('\n');
test.print();
System.out.print('\n');
for(int i = 0; i < 10; i++){
test.updateE(sampleDataGen());
}
test.updateAll();
System.out.print('\n');
test.print();
System.out.print('\n');
}
}
我的猜测是,由于样本太小,有时某些观察结果不存在概率。但是有一些确认会很好。
答案 0 :(得分:1)
你可以参考&#34; Scaling&#34; Rabiner's paper中的部分,解决了下溢问题。
您也可以在日志空间中进行计算,这是HTK和R的作用。乘法和除法成为加法和减法。对于其他两个,请查看相应工具包中的LAdd
/ LSub
和logspace_add
/ logspace_sub
函数。
log-sum-exp技巧也可能有所帮助。