我的蛇AI无法正常工作,我也不知道为什么

时间:2019-07-18 17:00:57

标签: javascript tensorflow machine-learning

在我的代码中,我正在使用张量流制作一个神经网络,并使它学会玩蛇的游戏。

首先,我通过使蛇向10,000个游戏方向随机移动来生成训练数据,并收集每次移动的数据。在我的NN中,我输入: -如果周围有障碍物1个街区 -与食物的角度 -建议的方向

然后产生一个输出,该输出可以是任何实数的数字。

在NN经过训练之后,我通过使其玩游戏来对其进行测试。在这个阶段,我给它所有相同的输入,但是在建议的方向上有所不同。然后,无论哪个建议的方向得分最高,蛇都会朝该方向移动。

我正在模仿这个项目,但是使用javascript: Project

我基本上已经复制了几乎所有特定方面,例如, -NN结构和激活以及优化器和损失 -要提供什么输入和输出

我将提供以下代码。我提供的代码是蛇对象中发生的一切的全部内容。我不包括在内,因为我认为这不是问题。我还将包括我的AI脚本,但我也不认为问题出在该代码中。

问题: 我无法获得我要模仿的项目中看到的结果。蛇也总是死亡。有时,当蛇似乎正朝着食物前进时,它只是在食物周围盘旋而没有实际食用。奇怪的是,我训练的次数越多,情况就越糟。另外,我对时间有疑问。要创建10,000个游戏,我的计算机需要1-2分钟,而训练我的计算机则需要每个时间3-4分钟。考虑我正在使用浏览器和tensorflow.js,这是否花了太多时间?

我认为问题出在我为每个动作分配分数或给定输入中的某个地方。我觉得这些还不够,或者有一些错误,这使这条蛇无法学习我想要的学习方式。

主要游戏循环的代码:

import * as snake from "./snake.js";
import * as AI from "/ai.js";

//Canvas and document information
const c = document.getElementById("can");
const ctx = c.getContext("2d");
const width = can.width;
const height = can.height;
const state = document.getElementById("state");
const score = document.getElementById("score");
const restart1 = document.getElementById("restart");
let model;
let numOfNoSurvive = 0;
let numOfWrongDir= 0;
let numOfRightDir = 0;

//Global constants:
let end = false;
const rez = 20;

//Features to keep track of:
let obsleft
let obsright
let obsfront

//utility functions:
function indexOfMax(arr) {
    if (arr.length === 0) {
        return -1;
    }
    var max = arr[0];
    var maxIndex = 0;
    for (var i = 1; i < arr.length; i++) {
        if (arr[i] > max) {
            maxIndex = i;
            max = arr[i];
        }
    }

    return maxIndex;
}

function randomIntFromInterval(min, max) // min and max included
{
    return Math.floor(Math.random() * (max - min + 1) + min);
}

//restart function:
function restart() {
    snake1.body = [];
    snake1.body[0] = [0, 0];
    snake1.dirx = rez;
    snake1.diry = 0;
    snake1.len = 1;
    end = false;
    state.innerHTML = "Have Fun!";
    score.innerHTML = "0";
    food1.x = Math.floor((Math.random() * (width)) / rez) * rez;
    food1.y = Math.floor((Math.random() * (height)) / rez) * rez;

}

restart1.onclick = restart;

let dir;

//Input handling
function press(evt) {
    if (evt.key == "w") {
        if (snake1.diry != rez) {
            snake1.diry = -rez;
            snake1.dirx = 0;
            dir = "u";
        }
    } else if (evt.key == "s") {
        if (snake1.diry != -rez) {
            snake1.diry = rez;
            snake1.dirx = 0;
            dir = "d";
        }
    } else if (evt.key == "a") {
        if (snake1.dirx != rez) {
            snake1.diry = 0;
            snake1.dirx = -rez;
            dir = "l";
        }
    } else if (evt.key == "d") {
        if (snake1.dirx != -rez) {
            snake1.diry = 0;
            snake1.dirx = rez;
            dir = "r";
        }
    }

}
//document.addEventListener("keydown", press, false);

//Food class:
class Food {
    constructor() {
        this.x = Math.floor((Math.random() * (width)) / rez) * rez;
        this.y = Math.floor((Math.random() * (height)) / rez) * rez;
        console.log(this.x, this.y);

        this.width = rez;
        this.height = rez;
    }

    update() {
        for (let i = 0; i < snake1.body.length; i++) {
            if (snake1.body[i][0] == this.x && snake1.body[i][1] == this.y || end == true) {
                this.x = Math.floor((Math.random() * (width)) / rez) * rez;
                this.y = Math.floor((Math.random() * (height)) / rez) * rez;
                snake1.grow();
                score.innerHTML = (parseInt(score.innerHTML) + 1).toString();
                return true;
            }
        }

        return false;
    }

    show() {
        ctx.fillStyle = "red";
        ctx.fillRect(this.x, this.y, this.width, this.height);
    }

}


//Creating the objects in the game:
let snake1 = new snake.Snake();
let food1 = new Food();



function update() {
    let score;

    if (snake1.endGame()) {
        state.innerHTML = "Game Over"
        end = true;
        score = -1;
    } else if (!end) {
        let head = snake1.body[snake1.body.length - 1];
        let hx = head[0];
        let hy = head[1];
        let fx = food1.x;
        let fy = food1.y;
        let eat = false;
        let prevlen = snake1.len;
        let distB = Math.sqrt((Math.pow((fx-hx), 2)) + (Math.pow(fy-hy,2)));
        snake1.update();
        if(food1.update()){
            eat = true;
        }
        head = snake1.body[snake1.body.length - 1];
        hx = head[0];
        hy = head[1];
        fx = food1.x;
        fy = food1.y;
        let newlen = snake1.len;
        let distA = Math.sqrt((Math.pow((fx-hx), 2)) + (Math.pow(fy-hy,2)));
        console.log(hx, hy, fx, fy, distA, distB);
        if(distA < distB || newlen > prevlen){
            score = 1;
        }else{
            score = 0;
        }
    }

    return score;
}

function draw() {
    if (end) {
        ctx.fillStyle = "black";
        ctx.fillRect(0, 0, width, height);
    } else {
        ctx.fillStyle = "white";
        ctx.fillRect(0, 0, width, height);
        snake1.show();
        food1.show();
    }
}

function findingLeftRightFront() {
    //Get the features and information around the snake
    let features = [0, 0, 0];
    let head = snake1.body[snake1.body.length - 1];
    let hr;
    let hl;
    let hf;
    //set direction ****
    if (snake1.dirx == rez && snake1.diry == 0) {
        dir = "r";
    }
    if (snake1.dirx == -rez && snake1.diry == 0) {
        dir = "l";
    }
    if (snake1.dirx == 0 && snake1.diry == rez) {
        dir = "d";
    }
    if (snake1.dirx == 0 && snake1.diry == -rez) {
        dir = "u";
    }

    //Find left right and forward based on direction
    switch (dir) {
        case "u":
            hr = [head[0] + rez, head[1]];
            hl = [head[0] - rez, head[1]];
            hf = [head[0], head[1] - rez];
            break;
        case "d":
            hr = [head[0] - rez, head[1]];
            hl = [head[0] + rez, head[1]];
            hf = [head[0], head[1] + rez];
            break;
        case "l":
            hr = [head[0], head[1] - rez];
            hl = [head[0], head[1] + rez];
            hf = [head[0] - rez, head[1]];
            break;
        case "r":
            hr = [head[0], head[1] + rez];
            hl = [head[0], head[1] - rez];
            hf = [head[0] + rez, head[1]];
            break;
    }

    //Body part
    //right, left, forward
    for (let i = 0; i < snake1.body.length - 1; i++) {
        let part = snake1.body[i];
        if (part[0] == hr[0] && part[1] == hr[1]) {
            features[0] = 1;
        }
        if (part[0] == hl[0] && part[1] == hl[1]) {
            features[1] = 1;
        }
        if (part[0] == hf[0] && part[1] == hf[1]) {
            features[2] = 1;
        }
    }

    //Boundaries
    if (hr[0] >= width || hr[0] < 0 || hr[1] >= height || hr[1] < 0) {
        features[0] = 1;
    }
    if (hl[0] >= width || hl[0] < 0 || hl[1] >= height || hl[1] < 0) {
        features[1] = 1;
    }
    if (hf[0] >= width || hf[0] < 0 || hf[1] >= height || hf[1] < 0) {
        features[2] = 1;
    }

    return features;
}

function findAngleToFood() {
    let head = snake1.body[snake1.body.length - 1];
    let hx = head[0];
    let hy = head[1];
    let fx = food1.x;
    let fy = food1.y;
    let dirx = snake1.dirx;
    let diry = snake1.diry;

    //get two vectors a and b: a - unit vector direction of snake, b - unit vector direction from head to food
    let a = [(dirx+hx) - hx, (diry+hy) - hy];
    let b = [fx - hx, fy - hy];



    let dista = Math.sqrt(Math.pow(a[0], 2) + Math.pow(a[1], 2));
    let distb = Math.sqrt(Math.pow(b[0], 2) + Math.pow(b[1], 2));


    a = [(a[0]/ dista), (a[1]/dista)];
    b = [(b[0]/ distb), (b[1]/distb)];

    dista = Math.sqrt(Math.pow(a[0], 2) + Math.pow(a[1], 2));
    distb = Math.sqrt(Math.pow(b[0], 2) + Math.pow(b[1], 2));

    //cos0 = x/y

    let x = (a[0] * b[0]) + (a[1] * b[1]);
    let y = dista * distb;


    let angle = Math.acos(x / y) / (Math.PI);
    if(Number.isNaN(angle)){
        console.log("CHECK");
        angle = 0;
    }

    console.log(angle);

    return angle;
}

function moveSnakeOnValue(val) {
    if (val == 0) {
        console.log("right");
        switch (dir) {
            case "u":
                snake1.diry = 0;
                snake1.dirx = rez;
                break;
            case "d":
                snake1.diry = 0;
                snake1.dirx = -rez;
                break;
            case "l":
                snake1.diry = -rez;
                snake1.dirx = 0;
                break;
            case "r":
                snake1.diry = rez;
                snake1.dirx = 0;
                break;
        }
    }
    if (val == 1) {
        console.log("left");
        switch (dir) {
            case "u":
                snake1.diry = 0;
                snake1.dirx = -rez;
                break;
            case "d":
                snake1.diry = 0;
                snake1.dirx = rez;
                break;
            case "l":
                snake1.diry = rez;
                snake1.dirx = 0;
                break;
            case "r":
                snake1.diry = -rez;
                snake1.dirx = 0;
                break;
        }
    }
    if (val == 2) {
        console.log("forward");
        switch (dir) {
            case "u":
                snake1.diry = -rez;
                snake1.dirx = 0;
                break;
            case "d":
                snake1.diry = rez;
                snake1.dirx = 0;
                break;
            case "l":
                snake1.diry = 0;
                snake1.dirx = -rez;
                break;
            case "r":
                snake1.diry = 0;
                snake1.dirx = rez;
                break;
        }
    }
}

function gameloop() {
    if (!end) {
        //fix after generating data
        let featuresr = findingLeftRightFront();
        featuresr.push(findAngleToFood());
        featuresr.push(0);


        let featuresl = findingLeftRightFront();
        featuresl.push(findAngleToFood());
        featuresl.push(1);

        let featuresf = findingLeftRightFront();
        featuresf.push(findAngleToFood());
        featuresf.push(2);

        let outputr1 = AI.predict(tf.tensor2d([featuresr]));


        let outputl1 = AI.predict(tf.tensor2d([featuresl]));


        let outputf1 = AI.predict(tf.tensor2d([featuresf]));

        let outputr = outputr1.dataSync();
        let outputl = outputl1.dataSync();
        let outputf = outputf1.dataSync();

        let outputs = [outputr[0], outputl[0], outputf[0]];

        console.log(featuresr, featuresl, featuresf);
        console.log(outputs);
        let i = indexOfMax(outputs);
        console.log(i);

        console.log(snake1.body[0]);
        moveSnakeOnValue(i);


        //find max from output and set snake direction
        let score = update();
        draw();
    } else {
        food1.update();
        restart();
    }
}

export async function generateTrainingData() {
    let dataToTrain = [];
    let inputdata = [];
    let outputdata = [];

    for (let i = 0; i < 10000; i++) {
        while (!end) {
            //generate input data [right, left, front, angle to food, suggested direction]
            let score = update();
            let f1 = findingLeftRightFront();
            f1.push(findAngleToFood());
            let dir = randomIntFromInterval(0, 2);
            f1.push(dir);

            //0 - right, 1 - left, 2- forward

            //move the snake using the suggested direction and get output
            moveSnakeOnValue(dir);
            if(score == -1){
                numOfNoSurvive++;
            }else if(score == 0){
                numOfWrongDir++;
            }else if(score == 1){
                numOfRightDir++;
            }

            //update data
            inputdata.push(f1);
            outputdata.push([score]);
        }
        restart();
    }

    dataToTrain.push(inputdata, outputdata);
    console.log(dataToTrain);
    return dataToTrain;
}

async function first() {
    //order of runtime
    //model = await tf.loadLayersModel('/My NN model/my-amazingmodel.json');
    let inputs;
    let outputs;

    await generateTrainingData().then((data) => {
        console.log(data[0]);
        console.log(data[1]);
        inputs = tf.tensor2d(data[0]);
        outputs = tf.tensor2d(data[1]);
    });
    console.log(numOfNoSurvive,numOfWrongDir,numOfRightDir);
    console.log(tf.memory().numTensors);
    await AI.train(inputs, outputs);
    window.setInterval(gameloop, 1000/8);
    //await AI.model.save('downloads://my-amazingmodel');
}



first();

AI脚本的代码(使用张量flow.js):

import * as game from "/game.js";

//Canvas and document information
const c = document.getElementById("can");
const ctx = c.getContext("2d");
const width = can.width;
const height = can.height;
const learningrate = 0.01;

/*
Features:
    - The obstacle on the right of the snake head
    - The obstacle on the left of the snake head
    - The obstacle in front of the snake head
    - Suggested direction
    So far, these features keep the snake from dying

    - The normalized angle between the snake head and the food
    This makes the score go up
*/

//Making the model:

export const model = tf.sequential();
model.add(tf.layers.dense({ inputShape: [5], units: 25, activation: 'relu' })); // Hidden
model.add(tf.layers.dense({ units: 1, activation: 'linear' })); // Output
const sgdOpt = tf.train.sgd(learningrate);
const config = {
    optimizer: "adam",
    loss: "meanSquaredError"
}
model.compile(config);

export async function train(inputs, outputs) {
    console.log(inputs.print(), outputs.print());
    let history = await model.fit(inputs, outputs, {
        batchSize: 32,
        epochs: 3,
        shuffle: true,
        callbacks: {onEpochEnd: (epoch, logs) => console.log("epoch:", epoch, logs.loss)}
      });

    inputs.dispose();
    outputs.dispose();
    console.log(history.history.loss[0]);
    console.log("Num: ", tf.memory().numTensors);
    console.log("training completed");
}


export function predict(input) {
    return model.predict(input);
}

我已经为此工作了几天,并感到非常沮丧,因为我似乎无法找到导致蛇AI混乱的错误。

在项目链接中,作者还提供了指向他自己的python代码的链接。我确保所做的一切在概念上都与他的代码相似。但是结果仍然不一样。

我希望这是一个有效的问题,并感谢您阅读。

** edit,我认为如果我添加其余代码,那会更好,因为那里可能有问题。

代码的蛇形部分:

//Canvas and document information
const c = document.getElementById("can");
const ctx = c.getContext("2d");
const width = can.width;
const height = can.height;
const rez = 20;

export class Snake {
    constructor() {
        this.body = [];
        this.body[0] = [0, 0];
        this.pressUp = false;
        this.pressDown = false;
        this.pressLeft = false;
        this.pressRight = false;
        this.dirx = rez;
        this.diry = 0;
        this.len = 1;
    }

    update() {
        let head = [...this.body[this.body.length - 1]];
        this.body.shift();
        head[0] += this.dirx;
        head[1] += this.diry;
        this.body.push(head);
    }

    endGame() {
        let x = this.body[this.body.length - 1][0];
        let y = this.body[this.body.length - 1][1];

        if (x >= width || x < 0 || y >= height || y < 0) {
            console.log("beep");
            return true;
        }

        for (let i = 0; i < this.body.length - 1; i++) {
            let part = this.body[i];
            if (part[0] == x && part[1] == y) {
                return true;

            }
        }
        return false;
    }

    show() {
        for (let i = 0; i < this.body.length; i++) {
            ctx.fillStyle = "black";
            ctx.fillRect(this.body[i][0], this.body[i][1], rez, rez);
        }
    }

    grow() {
        let head = [...this.body[this.body.length - 1]];
        this.len++;
        head[0] += this.dirx;
        head[1] += this.diry;
        this.body.push(head);
    }

}

**另一项编辑(抱歉)-如果您想在自己的浏览器中运行它,请执行以下操作: HTML:

<!DOCTYPE html>
<html lang="en" dir="ltr">
  <head>
    <meta charset="utf-8">
    <title>AI</title>
    <link rel="stylesheet" type="text/css" href="style.css">
  </head>
  <body>

    <canvas id="can" height="500" width="500" ></canvas> 
    <button id="restart">Restart</button>
    <p id="state">Have Fun!</p>
    <p id="score">0</p>
    <script src="snake.js"  type="module"></script>
    <script src="game.js"  type="module"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>

  </body>

</html>

CSS:

body {
    margin:0;
}
canvas{
    display:block;
    border: 3px solid grey;
}
.disable-selection {
    -moz-user-select: none; /* Firefox */
     -ms-user-select: none; /* Internet Explorer */
  -khtml-user-select: none; /* KHTML browsers (e.g. Konqueror) */
 -webkit-user-select: none; /* Chrome, Safari, and Opera */
 -webkit-touch-callout: none; /* Disable Android and iOS callouts*/
}

问题总结:

如何标记蛇的输入,以便获得良好的训练数据?我在项目中尝试了一个,但是训练得不好。而且,如果您认为我的标签是正确的,那么代码中是否存在错误,无法使蛇正确学习? <-我在训练和标签方面找不到python代码和javascript之间的任何区别。数字上是否有一些细微调整,以使其学习更好(例如:学习率)。

目标:训练后,snake AI必须能够证明它可以可靠地从10,000场比赛中获得至少10分作为训练数据。该项目中的人获得的收益超过100。

更改和修改:

-同样,它在代码中,但我没有明确说明。我感觉到了我如何奖励行动。到目前为止,模仿该项目,-1尚不存在。 1是朝正确的方向前进。而0的方向错误。这些数字是每个数据的“预期输出”或“标签”

  • 我添加了蛇形部分,如果有问题可能导致代码出现故障

  • 我添加了HTML和CSS,基本上我的所有项目现在都在那里

  • 我再次修复了代码。现在,有关该问题的所有代码均已完全正常运行。输入形状有误。如果您将index.html和所有js文件上传到同一文件夹中。您应该看到它可以正常工作而没有解决问题。 (蛇无法正常学习)。

0 个答案:

没有答案