ML代理-多个代理中断了培训

时间:2019-12-28 18:52:24

标签: c# unity3d machine-learning artificial-intelligence ml-agent

我一直在研究一种能使腰部保持在一定高度的自平衡剂。最近,我升级了“大腿”以允许3轴旋转,而不是以前的2轴旋转。完成此操作并修改ml代理代码以允许使用child sensors之后,这些代理现在似乎不再与一个以上的代理/区域一起使用。我不确定为什么会这样。需要明确的是,唯一的工作代理人表现得比平常“更具爆炸性” 。仅靠它自己时,尝试保持平衡要平静得多。也许我在这个过程中搞砸了其他东西?如果有人有什么想法,我什么都愿意。谢谢!

Broken ML Agents Inspector of agent

代理脚本:

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
using Random = UnityEngine.Random;

public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject floor;
    public GameObject sensor;
    public GameObject waist;
    public GameObject wFront;           //Used to check balance of waist.
    public GameObject wBack;           //Used to check balance of waist.
    public GameObject hipR;
    public GameObject hipL;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    public float bodyMoveSensitivity = 3f;

    public static GameObject[] bodyParts = new GameObject[11];
    public static HingeJoint[] hingeParts = new HingeJoint[11];
    public static JointLimits[] jntLimParts = new JointLimits[11];

    public static Vector3[] posStart = new Vector3[11];
    public static Vector3[] eulerStart = new Vector3[11];

    public void Start() {
        bodyParts = new GameObject[] { waist /*0*/, buttR /*1*/, buttL /*2*/, thighR /*3*/, thighL /*4*/, legR /*5*/, legL /*6*/, footR /*7*/, footL /*8*/, hipR /*9*/, hipL /*10*/};

        for (int i = 0; i < bodyParts.Length; i++) {
            posStart[i] = bodyParts[i].transform.position;
            eulerStart[i] = bodyParts[i].transform.eulerAngles;
            if (bodyParts[i].GetComponent<HingeJoint>() != null) {
                hingeParts[i] = bodyParts[i].GetComponent<HingeJoint>();
                hingeParts[i].limits = jntLimParts[i];
            }
        }
    }

    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();
    }

    public override void AgentReset() {
        floor.transform.eulerAngles = new Vector3(Random.Range(-15, 15), 0, Random.Range(-15, 15));             //Floor rotation

        for (int i = 0; i < bodyParts.Length; i++) {
            bodyParts[i].transform.position = posStart[i];
            bodyParts[i].transform.eulerAngles = eulerStart[i];
            bodyParts[i].GetComponent<Rigidbody>().velocity = Vector3.zero;
            bodyParts[i].GetComponent<Rigidbody>().angularVelocity = Vector3.zero;
        }
        waist.transform.eulerAngles = new Vector3(0, Random.Range(0, 360), 0);

        jntLimParts[1].max = 0;
        jntLimParts[1].min = jntLimParts[2].max - 1;
        hingeParts[1].limits = jntLimParts[2];

        jntLimParts[2].max = 0;
        jntLimParts[2].min = jntLimParts[2].max - 1;
        hingeParts[2].limits = jntLimParts[2];

        jntLimParts[3].max = 15;
        jntLimParts[3].min = jntLimParts[3].max - 1;
        hingeParts[3].limits = jntLimParts[3];

        jntLimParts[4].max = 15;
        jntLimParts[4].min = jntLimParts[4].max - 1;
        hingeParts[4].limits = jntLimParts[4];

        jntLimParts[5].max  = -15;
        jntLimParts[5].min = jntLimParts[5].max - 1;
        hingeParts[5].limits = jntLimParts[5];

        jntLimParts[6].max = -15;
        jntLimParts[6].min = jntLimParts[6].max - 1;
        hingeParts[6].limits = jntLimParts[6];

        jntLimParts[7].max = 15;
        jntLimParts[7].min = jntLimParts[7].max - 1;
        hingeParts[7].limits = jntLimParts[7];

        jntLimParts[8].max = 15;
        jntLimParts[8].min = jntLimParts[8].max - 1;
        hingeParts[8].limits = jntLimParts[8];

        jntLimParts[9].max = 0;
        jntLimParts[9].min = jntLimParts[9].max - 1;
        hingeParts[9].limits = jntLimParts[9];

        jntLimParts[10].max = 0;
        jntLimParts[10].min = jntLimParts[10].max - 1;
        hingeParts[10].limits = jntLimParts[10];
    }

    public override void AgentAction(float[] vectorAction) {

        float buttRDir = 0;
        int buttRVec = (int)vectorAction[0];
        switch (buttRVec) {
            case 1:
                buttRDir = 0;
                break;
            case 2:
                buttRDir = bodyMoveSensitivity;
                break;
            case 3:
                buttRDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[1].max < 60 && jntLimParts[1].min > -60) {
            jntLimParts[1].max += buttRDir;
            jntLimParts[1].min = jntLimParts[1].max - 1;
            hingeParts[1].limits = jntLimParts[1];
        }
        else {
            if (jntLimParts[1].min <= -60) {
                jntLimParts[1].max = -58;

            }
            else if (jntLimParts[1].max >= 60) {
                jntLimParts[1].max = 59;
            }
            jntLimParts[1].min = jntLimParts[1].max - 1;
        }

        float buttLDir = 0;
        int buttLVec = (int)vectorAction[1];
        switch (buttLVec) {
            case 1:
                buttLDir = 0;
                break;
            case 2:
                buttLDir = bodyMoveSensitivity;
                break;
            case 3:
                buttLDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[2].max < 60 && jntLimParts[2].min > -60) {
            jntLimParts[2].max += buttLDir;
            jntLimParts[2].min = jntLimParts[2].max - 1;
            hingeParts[2].limits = jntLimParts[2];
        }
        else {
            if (jntLimParts[2].min <= -60) {
                jntLimParts[2].max = -58;

            }
            else if (jntLimParts[2].max >= 60) {
                jntLimParts[2].max = 59;
            }
            jntLimParts[2].min = jntLimParts[2].max - 1;
        }

        float thighRDir = 0;
        int thighRVec = (int)vectorAction[2];
        switch (thighRVec) {
            case 1:
                thighRDir = 0;
                break;
            case 2:
                thighRDir = bodyMoveSensitivity;
                break;
            case 3:
                thighRDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[3].max < 80 && jntLimParts[3].min > -80) {
            jntLimParts[3].max += thighRDir;
            jntLimParts[3].min = jntLimParts[3].max - 1;
            hingeParts[3].limits = jntLimParts[3];
        }
        else {
            if (jntLimParts[3].min <= -80) {
                jntLimParts[3].max = -78;

            }
            else if (jntLimParts[3].max >= 80) {
                jntLimParts[3].max = 79;
            }
            jntLimParts[3].min = jntLimParts[3].max - 1;
        }

        float thighLDir = 0;
        int thighLVec = (int)vectorAction[3];
        switch (thighLVec) {
            case 1:
                thighLDir = 0;
                break;
            case 2:
                thighLDir = bodyMoveSensitivity;
                break;
            case 3:
                thighLDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[4].max < 80 && jntLimParts[4].min > -80) {
            jntLimParts[4].max += thighLDir;
            jntLimParts[4].min = jntLimParts[4].max - 1;
            hingeParts[4].limits = jntLimParts[4];
        }
        else {
            if (jntLimParts[4].min <= -80) {
                jntLimParts[4].max = -78;

            }
            else if (jntLimParts[4].max >= 80) {
                jntLimParts[4].max = 79;
            }
            jntLimParts[4].min = jntLimParts[4].max - 1;
        }

        float legRDir = 0;
        int legRVec = (int)vectorAction[4];
        switch (legRVec) {
            case 1:
                legRDir = 0;
                break;
            case 2:
                legRDir = bodyMoveSensitivity;
                break;
            case 3:
                legRDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[5].max < 5 && jntLimParts[5].min > -80) {
            jntLimParts[5].max += legRDir;
            jntLimParts[5].min = jntLimParts[5].max - 1;
            hingeParts[5].limits = jntLimParts[5];
        }
        else {
            if (jntLimParts[5].min <= -80) {
                jntLimParts[5].max = -78;

            }
            else if (jntLimParts[5].max >= 5) {
                jntLimParts[5].max = 4;
            }
            jntLimParts[5].min = jntLimParts[5].max - 1;
        }

        float legLDir = 0;
        int legLVec = (int)vectorAction[5];
        switch (legLVec) {
            case 1:
                legLDir = 0;
                break;
            case 2:
                legLDir = bodyMoveSensitivity;
                break;
            case 3:
                legLDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[6].max < 5 && jntLimParts[6].min > -80) {
            jntLimParts[6].max += legLDir;
            jntLimParts[6].min = jntLimParts[6].max - 1;
            hingeParts[6].limits = jntLimParts[6];
        }
        else {
            if (jntLimParts[6].min <= -80) {
                jntLimParts[6].max = -78;

            }
            else if (jntLimParts[6].max >= 5) {
                jntLimParts[6].max = 4;
            }
            jntLimParts[6].min = jntLimParts[6].max - 1;
        }

        float footRDir = 0;
        int footRVec = (int)vectorAction[6];
        switch (footRVec) {
            case 1:
                footRDir = 0;
                break;
            case 2:
                footRDir = bodyMoveSensitivity;
                break;
            case 3:
                footRDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[7].max < 50 && jntLimParts[7].min > -50) {
            jntLimParts[7].max += footRDir;
            jntLimParts[7].min = jntLimParts[7].max - 1;
            hingeParts[7].limits = jntLimParts[7];
        }
        else {
            if (jntLimParts[7].min <= -50) {
                jntLimParts[7].max = -48;

            }
            else if (jntLimParts[7].max >= 50) {
                jntLimParts[7].max = 49;
            }
            jntLimParts[7].min = jntLimParts[7].max - 1;
        }

        float footLDir = 0;
        int footLVec = (int)vectorAction[7];
        switch (footLVec) {
            case 1:
                footLDir = 0;
                break;
            case 2:
                footLDir = bodyMoveSensitivity;
                break;
            case 3:
                footLDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[8].max < 50 && jntLimParts[8].min > -50) {
            jntLimParts[8].max += footLDir;
            jntLimParts[8].min = jntLimParts[8].max - 1;
            hingeParts[8].limits = jntLimParts[8];
        }
        else {
            if (jntLimParts[8].min <= -50) {
                jntLimParts[8].max = -48;

            }
            else if (jntLimParts[8].max >= 50) {
                jntLimParts[8].max = 49;
            }
            jntLimParts[8].min = jntLimParts[8].max - 1;
        }

        float hipRDir = 0;
        int hipRVec = (int)vectorAction[9];
        switch (hipRVec) {
            case 1:
                hipRDir = 0;
                break;
            case 2:
                hipRDir = bodyMoveSensitivity;
                break;
            case 3:
                hipRDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[9].max < 45 && jntLimParts[9].min > -45) {
            jntLimParts[9].max += hipRDir;
            jntLimParts[9].min = jntLimParts[9].max - 1;
            hingeParts[9].limits = jntLimParts[9];
        }
        else {
            if (jntLimParts[9].min <= -45) {
                jntLimParts[9].max = -43;

            }
            else if (jntLimParts[9].max >= 45) {
                jntLimParts[9].max = 44;
            }
            jntLimParts[9].min = jntLimParts[9].max - 1;
        }

        float hipLDir = 0;
        int hipLVec = (int)vectorAction[10];
        switch (hipRVec) {
            case 1:
                hipLDir = 0;
                break;
            case 2:
                hipLDir = bodyMoveSensitivity;
                break;
            case 3:
                hipLDir = -bodyMoveSensitivity;
                break;
        }
        if (jntLimParts[10].max < 45 && jntLimParts[10].min > -45) {
            jntLimParts[10].max += hipLDir;
            jntLimParts[10].min = jntLimParts[10].max - 1;
            hingeParts[10].limits = jntLimParts[10];
        }
        else {
            if (jntLimParts[10].min <= -45) {
                jntLimParts[10].max = -43;

            }
            else if (jntLimParts[10].max >= 45) {
                jntLimParts[10].max = 44;
            }
            jntLimParts[10].min = jntLimParts[10].max - 1;
        }

        float waistDir = 0;
        int waistVec = (int)vectorAction[8];
        switch (footLVec) {
            case 1:
                waistDir = 0;
                break;
            case 2:
                waistDir = bodyMoveSensitivity;
                break;
            case 3:
                waistDir = -bodyMoveSensitivity;
                break;
        }
        bodyParts[0].transform.Rotate(0, waistDir, 0);




        sensor.transform.eulerAngles = new Vector3(0, 0, 0);

        if ( wFront.transform.position.y < wBack.transform.position.y-1 || wFront.transform.position.y > wBack.transform.position.y + 1 || buttR.transform.position.y < buttL.transform.position.y - 1 || buttR.transform.position.y > buttL.transform.position.y + 1) {                //Maintain waist rotation.
            AddReward(-.2f);
        }
        else {
            AddReward(.01f);
        }

        if (waist.transform.position.y <= -3) {             //Maintain waist height.
            AddReward(-.2f);
            Done();
        }
        else {
            AddReward(.01f);
        }

        if(waist.transform.position.x > posStart[0].x + 2 || waist.transform.position.x < posStart[0].x - 2 || waist.transform.position.z > posStart[0].z + 2 || waist.transform.position.z < posStart[0].z - 2) {              //Maintain waist position.
            AddReward(-.2f);
        }
        else {
            AddReward(.01f);
        }
    }

    public override void CollectObservations() {

        for (int i = 0; i < bodyParts.Length; i++) {
            AddVectorObs(bodyParts[i].transform.position);
            AddVectorObs(bodyParts[i].transform.rotation);
            AddVectorObs(bodyParts[i].GetComponent<Rigidbody>().velocity);
            AddVectorObs(bodyParts[i].GetComponent<Rigidbody>().angularVelocity);
            AddVectorObs(jntLimParts[i].max);
            AddVectorObs(jntLimParts[i].min);
            AddVectorObs(wFront.transform.position.y);
            AddVectorObs(wFront.transform.rotation);
            AddVectorObs(wBack.transform.position.y);
            AddVectorObs(wBack.transform.rotation);

        }
    }
}

1 个答案:

答案 0 :(得分:0)

不幸的是,为了解决此问题,我不得不从头开始。完成此操作后,请确保我拥有 MLAgent的更新版本,一切都再次正常。