ML-Agents代理未重置?

时间:2019-12-12 16:55:53

标签: c# unity3d machine-learning game-physics ml-agent

我一直在制作一双能够自我平衡的腿。如果它的“腰部”(waist)低于某个 y 坐标值(即跌倒/绊倒),该区域就应当重置,并从它的奖励分数中扣分。我是机器学习新手,请多包涵!为什么代理跌倒后没有重置?

(截图:Legs 训练器报告 / Legs trainer report;检查器中的代理设置 / Agents in inspector)




代理代码(已更新):

    using MLAgents;
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;

    using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// Agent that learns to keep a jointed pair of legs upright. Each step it turns
// eight joints about their Y axis according to the discrete actions, is
// rewarded while the waist stays high, and ends the episode once the waist has
// fallen below a threshold; AgentReset then restores the recorded start pose.
public class BalanceAgent : Agent
{
    // Waist height (world Y) above which the agent earns the per-step bonus.
    private const float UprightWaistHeight = -1f;
    // Waist height (world Y) at or below which the agent counts as fallen.
    private const float FallenWaistHeight = -3f;

    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    public GameObject[] bodyParts = new GameObject[9];
    public Vector3[] posStart = new Vector3[9];
    public Vector3[] eulerStart = new Vector3[9];

    // Rigidbody of each entry in bodyParts, cached so resets don't call GetComponent.
    private Rigidbody[] bodyRigidbodies;

    // Joints driven by vectorAction[0..7], in action-branch order.
    private GameObject[] actuatedJoints;

    // Records the starting pose of every body part so AgentReset can restore it.
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        bodyParts = new GameObject[] { waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL };
        actuatedJoints = new GameObject[] { buttR, buttL, thighR, thighL, legR, legL, footR, footL };

        // Size the snapshot arrays from bodyParts instead of trusting the
        // serialized public arrays, whose length can be changed in the Inspector.
        posStart = new Vector3[bodyParts.Length];
        eulerStart = new Vector3[bodyParts.Length];
        bodyRigidbodies = new Rigidbody[bodyParts.Length];

        for (int i = 0; i < bodyParts.Length; i++) {
            Transform t = bodyParts[i].transform;
            posStart[i] = t.position;
            eulerStart[i] = t.eulerAngles;
            bodyRigidbodies[i] = bodyParts[i].GetComponent<Rigidbody>();
        }
    }

    // Puts every body part back at its recorded start pose and stops all motion.
    public override void AgentReset() {
        for (int i = 0; i < bodyParts.Length; i++) {
            bodyParts[i].transform.position = posStart[i];
            bodyParts[i].transform.eulerAngles = eulerStart[i];
            bodyRigidbodies[i].velocity = Vector3.zero;
            bodyRigidbodies[i].angularVelocity = Vector3.zero;
        }
    }

    // Applies one step of joint rotations, then scores the resulting waist height.
    // vectorAction[i] drives actuatedJoints[i]; see ActionToDirection for the mapping.
    public override void AgentAction(float[] vectorAction) {
        for (int i = 0; i < actuatedJoints.Length; i++) {
            actuatedJoints[i].transform.Rotate(0, ActionToDirection(vectorAction[i]), 0);
        }

        float waistHeight = waist.transform.position.y;

        // Shaped reward: small bonus while upright, small penalty while sagging.
        if (waistHeight > UprightWaistHeight) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        if (waistHeight <= FallenWaistHeight) {
            print("reset!");
            AddReward(-.1f);
            Done();
        }
    }

    // 9 joint angles + 3 waist position components = 12 observations.
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        // NOTE(review): the butt joints are rotated about Y in AgentAction but observed on X here — confirm intent.
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);
        AddVectorObs(waist.transform.position);
    }

    // Maps a discrete action value to a rotation direction in degrees per step:
    // 1 -> -1, 2 -> +1, anything else (0 or 3) -> 0.
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }
}




区域代码:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

// Training area that collects the BalanceAgent components beneath it and can
// move an agent back to the area's origin.
public class BalancingArea : Area
{
    // All BalanceAgent components on this GameObject and its children, gathered in Awake.
    public List<BalanceAgent> BalanceAgent { get; private set; }

    // The scene's BalanceAcademy, located in Awake (null if the scene has none).
    public BalanceAcademy BalanceAcademy { get; private set; }

    // Object whose X/Z position is used by ResetAgentPosition. When left
    // unassigned in the Inspector, this area's own transform is used instead.
    public GameObject area;

    private void Awake() {
        BalanceAgent = transform.GetComponentsInChildren<BalanceAgent>().ToList();    // Grabs all agents in area
        BalanceAcademy = FindObjectOfType<BalanceAcademy>();                          // Grabs the balance academy
    }

    // Empty Start/Update callbacks were removed: Unity still invokes empty
    // message methods every frame, which is pure overhead.

    // Moves the agent's root to the area's X/Z at height 0 with zero rotation.
    public void ResetAgentPosition(BalanceAgent agent) {
        // Use Unity's overloaded != so a destroyed or unassigned object counts as null.
        Transform origin = area != null ? area.transform : transform;
        agent.transform.position = new Vector3(origin.position.x, 0, origin.position.z);
        agent.transform.eulerAngles = new Vector3(0, 0, 0);
    }
}




BalanceAcademy的代码:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// Academy for the balancing scene. It declares no members of its own;
// all behavior comes from the ML-Agents Academy base class.
public class BalanceAcademy : Academy { }



用于运行训练器(trainer)的命令:

mlagents-learn config/trainer_config.yaml --run-id=balancetest09 --train

1 个答案:

答案 0 :(得分:1)

摘自 ML-Agents 文档中“创建新环境”(creating a new environment)一节:

  

初始化和重置代理

     

当代理(Agent)到达目标时,会将自身标记为完成(done),其代理重置函数(AgentReset)会把目标移动到一个随机位置。此外,如果代理从平台上滚落,重置函数会把它放回地板上。

     

要移动目标 GameObject,我们需要引用它的 Transform(Transform 存储了 GameObject 在 3D 世界中的位置、旋转和缩放)。为获得此引用,请在 RollerAgent 类中添加一个 Transform 类型的公共字段。Unity 组件的公共字段会显示在检查器(Inspector)窗口中,这样您就可以在 Unity 编辑器中选择用作目标的 GameObject。

     

要重置代理的速度(并在稍后对其施加力使其移动),我们需要引用它的刚体(Rigidbody)组件。刚体是 Unity 物理模拟的主要元素(详情请参阅 Unity 的 Physics 文档)。由于刚体组件与我们的代理脚本位于同一个 GameObject 上,获取该引用的最佳方式是使用 GameObject.GetComponent<T>(),我们可以在脚本的 Start() 方法中调用它。

     

到目前为止,我们的RollerAgent脚本如下:

using System.Collections.Generic;
using UnityEngine;
using MLAgents;

// Tutorial agent from the ML-Agents "new environment" walkthrough. When an
// episode resets, it recovers the agent if it has dropped below the floor and
// moves Target to a random spot within +/-4 units on X and Z.
public class RollerAgent : Agent
{
    Rigidbody rBody;
    void Start () {
        rBody = GetComponent<Rigidbody>();
    }

    public Transform Target;
    public override void AgentReset()
    {
        // Local coordinates keep this correct when the training area
        // (presumably the agent's parent) is moved or duplicated in the scene.
        if (this.transform.localPosition.y < 0)
        {
            // Guard in case AgentReset runs before Start has cached the Rigidbody.
            if (rBody == null)
            {
                rBody = GetComponent<Rigidbody>();
            }

            // If the Agent fell, zero its momentum
            this.rBody.angularVelocity = Vector3.zero;
            this.rBody.velocity = Vector3.zero;
            this.transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // Move the target to a new spot
        Target.localPosition = new Vector3(Random.value * 8 - 4,
                                           0.5f,
                                           Random.value * 8 - 4);
    }
}

因此,您应该重写 AgentReset 方法,以便重新设置代理各个关节的位置。首先,可以在 InitializeAgent 中记录每个关节的旋转和位置,然后在 AgentReset 中恢复它们。另外,还要将刚体的速度和角速度归零。

我在文档或示例中没有看到在 Update 中调用 Done 的写法,因此可能建议、甚至要求把它放到 AgentAction 中,代理才能正常工作。最好把 Update 中的所有内容都移到 AgentAction 中。

此外,您可能希望在特征向量中使用 transform.localEulerAngles(3 个分量:x、y、z),而不是 transform.localRotation(四元数,4 个分量:x、y、z、w)。如果确实要使用 localRotation,就不应忽略它的 w 分量。

总共看起来可能像这样:

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// Agent that learns to balance a jointed pair of legs. Eight discrete action
// branches rotate the joints about their Y axis; reward depends on the waist
// height, and the episode ends (AgentReset restores the initial pose) once the
// waist falls too low.
public class BalanceAgent : Agent
{
    // Waist height (world Y) above which the agent earns the per-step bonus.
    private const double UprightWaistHeight = -1.3;
    // Waist height (world Y) at or below which the agent counts as fallen.
    private const float FallenWaistHeight = -3f;

    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;
    public GameObject goal;

    private List<GameObject> gameObjectsToReset;
    private List<Rigidbody> rigidbodiesToReset;
    private List<Vector3> initEulers;
    private List<Vector3> initPositions;

    // Joints driven by vectorAction[0..7], in action-branch order.
    private GameObject[] actuatedJoints;
    // Waist Rigidbody, cached so CollectObservations doesn't call GetComponent every step.
    private Rigidbody waistBody;

    // Snapshots every body part's starting pose and Rigidbody for later resets.
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        gameObjectsToReset = new List<GameObject> {
                waist, buttR, buttL, thighR, thighL, legR, legL,
                footR, footL };
        rigidbodiesToReset = new List<Rigidbody>(gameObjectsToReset.Count);
        initEulers = new List<Vector3>(gameObjectsToReset.Count);
        initPositions = new List<Vector3>(gameObjectsToReset.Count);

        foreach (GameObject g in gameObjectsToReset) {
            rigidbodiesToReset.Add(g.GetComponent<Rigidbody>());
            initEulers.Add(g.transform.eulerAngles);
            initPositions.Add(g.transform.position);
        }

        actuatedJoints = new GameObject[] { buttR, buttL, thighR, thighL, legR, legL, footR, footL };
        waistBody = waist.GetComponent<Rigidbody>();
    }

    // Restores the recorded start pose of every body part and stops all motion.
    public override void AgentReset() {
        for (int i = 0 ; i < gameObjectsToReset.Count ; i++) {
            Transform t = gameObjectsToReset[i].transform;
            t.position = initPositions[i];
            t.eulerAngles = initEulers[i];

            Rigidbody r = rigidbodiesToReset[i];
            r.velocity = Vector3.zero;
            r.angularVelocity = Vector3.zero;
        }
    }

    // Applies one step of joint rotations, then scores the resulting waist height.
    // vectorAction[i] drives actuatedJoints[i]; see ActionToDirection for the mapping.
    public override void AgentAction(float[] vectorAction) {
        for (int i = 0; i < actuatedJoints.Length; i++) {
            actuatedJoints[i].transform.Rotate(0, ActionToDirection(vectorAction[i]), 0);
        }

        float waistHeight = waist.transform.position.y;

        // Shaped reward: small bonus while upright, small penalty while sagging.
        if (waistHeight > UprightWaistHeight) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        if (waistHeight <= FallenWaistHeight) {
            // Apply the fall penalty before Done() so it clearly belongs to the
            // episode that is ending rather than relying on post-Done bookkeeping.
            AddReward(-.1f);
            Done();
        }
    }

    // 9 joint angles + 1 flag + 3 waist position components = 13 observations.
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        // NOTE(review): the butt joints are rotated about Y in AgentAction but observed on X here — confirm intent.
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);

        // This class never changes freezeRotation, so this is effectively a
        // constant; it is kept so the configured observation size stays the same.
        AddVectorObs(waistBody.freezeRotation);

        AddVectorObs(waist.transform.position);
    }

    // Maps a discrete action value to a rotation direction in degrees per step:
    // 1 -> -1, 2 -> +1, anything else (0 or 3) -> 0.
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }
}

最后,请确保将 BalanceAgent 的 Max Step 设置为足够大的值,以便观察代理是否会失败;起步时可以设为 500 或 1000。

(截图:Max Step 可在检查器(Inspector)中编辑)