强化学习中的政策迭代问题

时间:2012-08-19 14:47:25

标签: java reinforcement-learning

我必须解决策略迭代的问题,模型显示在

enter image description here

我制作了一个Java程序进行模拟,策略算法基于Sutton和Barto关于Reinforcement learning的书。 我相信java程序中的模型与图中的模型一致。 当我最终运行模拟时,我可以进行4次迭代,最终结果是正确的,并在教科书中给出最终答案。但是在教科书中,最终的答案只需要三次迭代,虽然最终的答案是正确的,迭代中的中间答案有微小的变化,我不知道是什么问题?

public class PolicyIteration{

    private static final double TOLERANCE = 0.1;
    private double gamma = 0.9;

    private int stateCount = 5;
    private int actionCount = 7;

    //states names
    private final String s[] = {"s1","s2","s3","s4","s5"};
    private final String ac[] = {"a","b","c","d","e","f","g"};
    // states
    private final int s1 =0;
    private final int s2 =1;
    private final int s3 = 2;
    private final int s4 = 3;
    private final int s5 = 4;

    // actions 
    private final int a = 0;
    private final int b = 1;
    private final int c = 2;
    private final int d = 3;
    private final int e = 4;
    private final int f = 5;
    private final int g = 6;

    // transition
    private double t[][][] = new double[stateCount][stateCount][actionCount];
    // rewards
    private double r[][][] = new double[stateCount][stateCount][actionCount];    
    // utility
    private double values[] = new double[stateCount];
    // policy
    private int policy[] = new int[stateCount];
    // init

    private void init(){        
        for(int i=0;i<values.length;i++){
            values[i] = 0;
        }       
        //(s1; a); (s2; c); (s3; e); (s4; f); (s5; g)
        policy[s1] = a; policy[s2] = c; policy[s3] = e; policy[s4] = f;  policy[s5] = g;
        //
        t[s1][s2][a] = 1;  r[s1][s2][a] = 0;
        t[s1][s3][b] = 1;  r[s1][s3][b] = 0;
        t[s2][s2][c] = 0.5;  r[s2][s2][c] = 0;
        t[s2][s4][c] = 0.5;  r[s2][s4][c] = 0;
        t[s3][s3][e] = 1;  r[s3][s3][e] = 1;
        t[s4][s5][f] = 1;  r[s4][s5][f] = 0;
        t[s4][s4][d] = 1;  r[s4][s4][d] = 10;
        t[s5][s5][g] = 1;  r[s5][s5][g] = 0;
    }

    public static void main(String args[]){

        PolicyIteration p = new PolicyIteration();
        p.init();
        p.run();
    }


    public void run(){
        int it = 0;

        int changed= 1;
        do{
            it++;
            System.out.println("Iteration :"+it);
            changed = train();

            for(int i=0;i<policy.length;i++){
                System.out.print( s[i]+"->" + ac[ policy[i] ]  +" ; ");
            }

            System.out.println();

            for(int i=0;i<policy.length;i++)
            System.out.print(values[i]+" ,");

            System.out.println();
        }while(changed>0);

    }

    public int train() {

        boolean valuesChanged = false;
        do {
            valuesChanged = false;
            // loop through all the states
            for (int i = 0; i < stateCount; i++) {
                // calculate the new value
                int action = policy[i];
                double actionVal = 0;
                for (int j = 0; j < stateCount; j++) {
                    actionVal += t[i][j][action]*(r[i][j][action] + gamma* values[j]);
                }
                // check if we're done
                if (Math.abs(values[i] - actionVal) > TOLERANCE) {
                    valuesChanged = true;
                }
                values[i] = actionVal;
            }
        } while (valuesChanged);
        int changed = 0;
        // calculate the new policy
        for (int i = 0; i < stateCount; i++) {
            // find the maximum action
            double maxActionVal = -Double.MAX_VALUE;
            int maxAction = 0;
            for (int action = 0; action < actionCount; action++) {
                double actionVal = 0;
                for (int j = 0; j < stateCount; j++) {
                    actionVal += t[i][j][action]*(r[i][j][action] + gamma* values[j]);                          
                }                
                if (actionVal >= maxActionVal) {
                    maxActionVal = actionVal;
                    maxAction = action;
                }

                if(i==s5 && action == g){
                    System.out.println("----  actionVal:"+actionVal+"    "+action);
                }
            }
            if (policy[i] != maxAction) {
                changed++;
                policy[i] = maxAction;
            }
        }
        return changed;
    }


}

教科书中的答案是

Initial policy:
s1 best action: a
s2 best action: c
s3 best action: e
s4 best action: f
s5 best action: g
After policy evaluation:
s1 value: 0.000
s2 value: 0.000
s3 value: 9.114
s4 value: 0.000
s5 value: 0.000
== End of iteration 1 ==

new policy:
s1 best action: b
s2 best action: c
s3 best action: e
s4 best action: d
s5 best action: g
After policy evaluation:
s1 value: 8.992
s2 value: 80.945
s3 value: 9.992
s4 value: 99.127
s5 value: 0.000
== End of iteration 2 ==

new policy:
s1 best action: a
s2 best action: c
s3 best action: e
s4 best action: d
s5 best action: g
After policy evaluation:
s1 value: 72.929
s2 value: 81.111
s3 value: 9.994
s4 value: 99.293
s5 value: 0.000
== End of iteration 3 ==

0 个答案:

没有答案