在Accord.NET的Q-Learning中使用数组表示状态

时间:2017-08-31 14:37:24

标签: reinforcement-learning accord.net q-learning

我正在尝试在Unity中为模拟蚂蚁实现Q-Learning。参照Accord的Animat示例，我已经实现了算法的主体部分。

现在我的代理有5个状态输入 - 其中三个来自检测前方障碍物的传感器(Unity中的RayCasts),其余两个是地图上的X和Y位置。

我的问题是 qLearning.GetAction(currentState) 只接受一个 int 参数。如何用数组（或张量）来表示代理的当前状态并实现我的算法？

这是我的代码:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Accord.MachineLearning;
using System;

/// <summary>
/// Steers a simulated ant with tabular Q-learning from Accord.NET, structured after
/// Accord's Animat sample: every frame the agent observes a state, picks one of
/// <see cref="ActionCount"/> steering actions, moves, and updates its Q-table.
/// </summary>
/// <remarks>
/// <c>QLearning.GetAction</c> and <c>QLearning.UpdateState</c> only accept an integer
/// state index, so a multi-part observation (grid position plus sensor reading) has to
/// be packed into a single integer in <c>[0, StateCount)</c>; see <see cref="EncodeState"/>.
/// To add more sensors, multiply <see cref="StateCount"/> by the number of readings of
/// each new sensor and add a matching digit in <see cref="EncodeState"/>; every extra
/// binary sensor doubles the table size.
/// </remarks>
public class AntManager : MonoBehaviour {
    // Number of discrete steering actions.
    private const int ActionCount = 4;

    // The map position is discretized into a GridCells x GridCells grid of cells.
    private const int GridCells = 16;

    // Distinct readings of the forward sensor: 0 = clear, 1 = non-terrain collider ahead.
    private const int SensorStates = 2;

    // One Q-table row per (cellX, cellY, sensor) combination.
    private const int StateCount = GridCells * GridCells * SensorStates;

    // World units per grid cell. NOTE(review): assumes the play area lies roughly within
    // [0, GridCells * cellSize) on both mapped axes; positions outside are clamped to the
    // border cells. Adjust to the real map size.
    private float cellSize = 1f;

    // Distance moved along the local forward axis each frame.
    private float stepDistance = 0.01f;

    // Accumulated turn rate in degrees per second; each action nudges it by +/-1.
    private float rotation = 0;

    // learning settings
    // Number of Update() steps over which exploration and learning rate decay linearly
    // to 0. NOTE(review): counted in frames, 100 steps is only a moment of play time.
    private int learningIterations = 100;
    private double explorationRate = 0.5;
    private double learningRate = 0.5;

    private double moveReward = 0;
    private double approachReward = 0.2;
    private double wallReward = -1;
    private double goalReward = 1;

    private float lastDistance = 0;

    // Forward sensor reading from the most recent move (0 = clear, 1 = obstacle ahead).
    private int hitInteger = 0;

    // Learning steps taken so far; drives the decay schedule.
    private int iteration = 0;

    // Q-Learning algorithm and its exploration policies, kept as fields so Update()
    // does not have to re-cast them every frame.
    private QLearning qLearning = null;
    private TabuSearchExploration tabuPolicy = null;
    private EpsilonGreedyExploration explorationPolicy = null;

    // Use this for initialization
    void Start () {
        explorationPolicy = new EpsilonGreedyExploration(explorationRate);
        tabuPolicy = new TabuSearchExploration(ActionCount, explorationPolicy);
        qLearning = new QLearning(StateCount, ActionCount, tabuPolicy);
    }

    // Update is called once per frame: one observe / act / learn step.
    void Update () {
        // Fraction of the decay schedule already used: 0 on the first step, 1 from step
        // learningIterations onwards (and 1 if learningIterations <= 0, avoiding 0/0).
        double progress = learningIterations > 0
            ? Math.Min(1.0, (double)iteration / learningIterations)
            : 1.0;
        explorationPolicy.Epsilon = explorationRate - progress * explorationRate;
        qLearning.LearningRate = learningRate - progress * learningRate;
        // Stop counting once the schedule is exhausted so the counter cannot overflow.
        if (iteration < learningIterations)
            iteration++;

        // The tabu list is deliberately NOT reset here: the action made tabu at the end of
        // the previous frame must still be in force for this GetAction call. It is reset
        // when the food is reached (end of an episode).
        int currentState = GetCurrentState();
        int action = qLearning.GetAction(currentState);
        double reward = UpdateAgentPosition(action);
        // Observe where the move actually led (new position and fresh sensor reading).
        int nextState = GetCurrentState();
        qLearning.UpdateState(currentState, action, reward, nextState);

        // For one step, forbid the action whose turn effect is opposite to this one.
        tabuPolicy.SetTabuAction((action + 2) % ActionCount, 1);
    }

    /// <summary>
    /// Packs the current observation (grid cell of the X/Y position plus the forward
    /// sensor reading) into one Q-table state index in [0, StateCount).
    /// </summary>
    private int GetCurrentState() {
        Vector3 position = transform.position;
        // NOTE(review): Rotate/Translate act in local space; if the ant walks on the XZ
        // ground plane, position.y never changes and position.z should be used instead.
        return EncodeState(ToCell(position.x), ToCell(position.y), hitInteger);
    }

    /// <summary>Maps a world coordinate to a grid cell index in [0, GridCells - 1].</summary>
    private int ToCell(float coordinate) {
        // Clamp in floating point before casting so negative or far-away positions can
        // neither overflow the cast nor yield an out-of-range Q-table index.
        double cell = Math.Floor(coordinate / cellSize);
        return (int)Math.Max(0.0, Math.Min(GridCells - 1, cell));
    }

    /// <summary>
    /// Mixed-radix encoding of (cellX, cellY, sensor) into a unique index in
    /// [0, StateCount), given 0 &lt;= cellX, cellY &lt; GridCells and
    /// 0 &lt;= sensor &lt; SensorStates.
    /// </summary>
    private static int EncodeState(int cellX, int cellY, int sensor) {
        return (cellY * GridCells + cellX) * SensorStates + sensor;
    }

    /// <summary>
    /// Applies the steering action, moves the ant one step forward, refreshes
    /// <see cref="hitInteger"/> from the forward sensor and returns the reward for the move.
    /// </summary>
    /// <param name="action">Action index in [0, ActionCount).</param>
    /// <returns>
    /// goalReward if the sensor sees food; otherwise wallReward if it sees any other
    /// non-terrain collider; otherwise approachReward if the move reduced the distance
    /// to the food; otherwise moveReward.
    /// </returns>
    private double UpdateAgentPosition(int action)
    {
        double reward = moveReward;

        // The four actions only have two distinct effects on the turn rate.
        // NOTE(review): actions 0/3 and 1/2 are duplicates; give them distinct effects
        // if four different behaviours are intended.
        switch (action)
        {
            case 0:
            case 3:
                rotation -= 1f;
                break;
            case 1:
            case 2:
                rotation += 1f;
                break;
        }

        transform.Rotate(0, rotation * Time.deltaTime, 0);
        transform.Translate(new Vector3(0, 0, stepDistance));

        // Distance shaping, measured AFTER the move so the reward belongs to this action.
        // The food may already have been eaten, in which case there is nothing to approach.
        GameObject food = GameObject.FindGameObjectWithTag("Food");
        if (food != null)
        {
            float distance = Vector3.Distance(transform.position, food.transform.position);
            if (distance < lastDistance)
                reward = approachReward;
            lastDistance = distance;
            Debug.Log(distance);
        }

        // Forward sensor: a ray of length 1 along the ant's forward axis.
        Ray sensorForward = new Ray(transform.position, transform.forward);
        Debug.DrawRay(transform.position, transform.forward * 1);

        RaycastHit hit;
        hitInteger = 0;
        if (Physics.Raycast(sensorForward, out hit, 1))
        {
            if (!hit.collider.CompareTag("Terrain"))
            {
                Debug.Log("Sensor Forward hit!");
                reward = wallReward;
                hitInteger = 1;
            }
            if (hit.collider.CompareTag("Food"))
            {
                Debug.Log("Sensor Found Food!");
                // Destroy the food that was actually hit, not whichever "Food" object
                // FindGameObjectWithTag happened to return.
                Destroy(hit.collider.gameObject);
                reward = goalReward;
                // Goal reached: start a fresh episode (tabu list and distance baseline).
                tabuPolicy.ResetTabuList();
                lastDistance = 0;
            }
        }

        return reward;
    }
}

1 个答案:

答案 0 :(得分:0)

Accord.NET 的文档（documentation）给出了这样一个例子：

// c1..c8 are 0/1 flags; each one is shifted into its own bit of the state index.
c1 | (c2 << 1) | (c3 << 2) | (c4 << 3) | (c5 << 4) | (c6 << 5) | (c7 << 6) | (c8 << 7)

这似乎是把若干个只取 0 或 1 的整数按位移位，组合成状态的二进制编码。你的代码可能需要类似这样的写法：

// Each shifted term must be 0 or 1; larger values would spill into the neighbouring bits.
int currentState = ((int)Math.Round(transform.position.x, 0) | ((int)Math.Round(transform.position.y, 0) << 1) | (hitInteger << 2));

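但是，您首先需要把状态映射为二进制变量，因此这段代码只适用于 2x2 网格。虽然示例中声明的是整数，但它们实际上是二进制值（0 或 1）：对取值为 2 或更大的变量这样移位没有意义，因为它们会占用多个位，并与相邻变量的位重叠。对于取值更多的变量，可以改用混合进制编码，例如 state = x + width * (y + height * hit)，其中 0 ≤ x < width，0 ≤ y < height。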

一种直观查看状态的方法是直接输出它的二进制表示：

// Base-2 string of the packed state: 1 | 0 | 4 = 5, so this yields "101".
Convert.ToString(1 | (0 << 1) | (1 << 2), 2)