神经网络学习算法

时间:2015-05-26 01:18:09

标签: java algorithm machine-learning neural-network

我一直致力于制作一个有目标的神经网络,击中一个移动的目标,它有基于从两个轴到射手到目标的距离,射手的旋转和风速的输入。

每5秒钟射手重置到屏幕上的不同位置。

我试图使用反向道具算法,通过根据子弹击中目标的距离计算误差,如果射击者在5秒窗口内没有射击,则也会传播错误

我的网络似乎没有正确学习,我想知道是否有人能指出我正确的方向。

这是我的代码:

    import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.util.Random;
import java.util.stream.DoubleStream;

import javax.swing.JFrame;


public class ClassMain extends JFrame{
    int ShooterLocationX = 100;
    int ShooterLocationY = 150;
    int shooterDiameter = 35;
    int targetX = 525;
    int targetY = 100;
    int targetWidth = 15;
    int targetLength = 50;
    double error;
    double targetVectorX=ShooterLocationX + 100;
    double targetVectorY = ShooterLocationY+shooterDiameter/2;
    boolean direction = false;
    boolean ShotFired = false;
    boolean rotateDirection;
    boolean start =true;
    boolean shotbeenFires;
    double ShooterRotation = 0;
    double[] outputs1 = new double[4];
    double[] outputs2 = new double[8];
    double[] outputs3 = new double[8];
    double[] outputs4 = new double[3];
    Percepton[] input = new Percepton[4];
    Percepton[] hidden = new Percepton[8];
    Percepton[] hidden2 = new Percepton[8];
    Percepton[] output = new Percepton[3];
    Shooter shot = new Shooter(ShooterRotation);
    Random random = new Random();
    long StartTime = 0;
    long CurrTime = 0;
    public static void main(String []args){
        ClassMain CM= new ClassMain();
        CM.setup();
    }
    public void setup(){
        setSize(900,300);
        setVisible(true);
        setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        setResizable(false);
        createBufferStrategy(2);
        repaint();
        for(int i = 0; i < 4; i ++){
            input[i] = new Percepton();
            input[i].setInputNumber(1);
        }
        for(int i =0;i < 8;i++){
            hidden[i] = new Percepton();
            hidden[i].setInputNumber(input.length);
        }
        for(int i =0;i < 8;i++){
            hidden2[i] = new Percepton();
            hidden2[i].setInputNumber(input.length);
        }
        for(int i = 0;i < 3; i++){
            output[i] = new Percepton();
            output[i].setInputNumber(hidden.length);
        }
    }

    public void run(){
        double[] temp = new double[10];
        if(start){
            StartTime = System.currentTimeMillis();
            start = false;
        }
        CurrTime = System.currentTimeMillis();
        if(CurrTime-StartTime > 5000){
            ShooterLocationX=random.nextInt(400);
            ShooterLocationY=random.nextInt(275);
            shot.newX = ShooterLocationX+(shooterDiameter/2)-5;
            shot.newY = ShooterLocationY+(shooterDiameter/2)-5;
            shot.windSpeed = (float)(random.nextInt(200)-100)/100;
            start = true;
            if(!shotbeenFires){
                error+=15;
            }
                for(int n = 0; n < output.length;n++){
                    output[n].learn(n, error, output[n].weights[n]);
                }
                for(int n = 0; n < hidden2.length;n++){
                    hidden2[n].learn(n, error, hidden2[n].weights[n]);
                }
                for(int n = 0; n < hidden.length;n++){
                    hidden[n].learn(n, error, hidden[n].weights[n]);
                }
                for(int n = 0; n < input.length;n++){
                    input[n].learn(n, error, input[n].weights[n]);
                }
                error = 0;
                shotbeenFires = false;
        }


        outputs1[0] = input[0].sigmoid((((Math.max(ShooterLocationX,targetX)-Math.min(ShooterLocationX,targetX))/300)*2)*input[0].weights[0],1*input[0].biasWeight);
        outputs1[1] = input[1].sigmoid((((Math.max(ShooterLocationY,targetY)-Math.min(ShooterLocationY,targetY))/150)*2)*input[1].weights[0],1*input[1].biasWeight);
        outputs1[2] = input[2].sigmoid((ShooterRotation)*input[2].weights[0],1*input[2].biasWeight);
        outputs1[3] = input[3].sigmoid((shot.windSpeed)*input[3].weights[0],1*input[3].biasWeight);

        System.out.println("\n");
        for(int n = 0;n < hidden.length;n++){
            for(int i = 0;i<temp.length;i++){
                temp[i] = 0;
            }
            for(int i = 0;i < outputs1.length;i++){
                temp[i] = outputs1[i]*hidden[n].weights[i];

            }

            outputs2[n] = (float)DoubleStream.of(temp).sum();
            outputs2[n] = hidden[1].sigmoid(outputs2[n],1*hidden[n].biasWeight);
            //System.out.print(outputs2[n]);
        }
        System.out.println("\n");
        for(int n = 0;n < hidden2.length;n++){
            for(int i = 0;i<temp.length;i++){
                temp[i] = 0;
            }
            for(int i = 0;i < outputs2.length;i++){
                temp[i] = outputs2[i]*hidden2[n].weights[i];

            }
            outputs3[n] = (float)DoubleStream.of(temp).sum();
            outputs3[n] = hidden2[1].sigmoid(outputs3[n],1*hidden2[n].biasWeight);
            //System.out.print(outputs3[n]);
        }

        System.out.println("\n");
        for(int n = 0;n < output.length;n++){
            for(int i = 0;i<temp.length;i++){
                temp[i] = 0;
            }


            for(int i = 0;i < outputs1.length;i++){
                    temp[i] = outputs3[n]*hidden[n].weights[i];

                }
                outputs4[n] = (float)DoubleStream.of(temp).sum();
                outputs4[n] = output[n].sigmoid(outputs4[n],1*output[n].biasWeight);
                //System.out.print(outputs4[n]);
            }
            if(!ShotFired){
                System.out.println("\n");
                if(outputs4[0] > 0.5){
                    ShotFired = true;
                    shot.rotation = ShooterRotation;
                }
                if (outputs4[1] > 0.5){
                    ShooterRotation = ShooterRotation + outputs3[1];
                }
                if(outputs4[2] > 0.5){
                    ShooterRotation =ShooterRotation - outputs3[2];
                }
                if(ShooterRotation*360 > 360){
                    ShooterRotation = ((ShooterRotation * 360) + 360) /360;
                }else if(ShooterRotation * 360 < 0){
                    ShooterRotation = ((ShooterRotation*360)-360) /360;
                }
            }


            for(int i = 0; i < 1; i++){
                if(ShotFired){
                    shotbeenFires = true;
                    shot.newLocationX(shot.newX);
                    shot.newLocationY(shot.newY);
                }
                if(targetY < 20 ||targetY+targetLength >300){
                    direction = !direction;
                }
                if(direction){
                    targetY += 8;
                }else{
                    targetY -= 8;
                }
                if(shot.newX + (15)>= targetX && shot.newX + 15 < targetX +targetWidth && shot.newY + 15> targetY && shot.newY <targetY + targetLength){
                    System.out.println("Target Hit");
                    error -= 10; 
                    shot.newX = ShooterLocationX + shooterDiameter/2-5;
                    shot.newY = ShooterLocationY + shooterDiameter/2-5;
                    ShotFired = false;
                }else if(shot.newX > 600|| shot.newX<10||shot.newY<10||shot.newY> 290){
                    System.out.println("MISS");
                    error += (shot.newY - targetY);
                    shot.newX = ShooterLocationX + shooterDiameter/2-5;
                    shot.newY = ShooterLocationY + shooterDiameter/2-5;
                    ShotFired = false;
                }

            }

            repaint();
        }

        public void paint(Graphics g){


            targetVectorX = ShooterLocationX + shooterDiameter/2 + Math.cos(Math.toRadians(ShooterRotation)) * 100;
                targetVectorY = ShooterLocationY + shooterDiameter/2 + Math.sin(Math.toRadians(ShooterRotation)) * 100;
                super.paint(g);
                Graphics2D g2 = (Graphics2D)g;
                g2.fillOval(ShooterLocationX, ShooterLocationY, shooterDiameter, shooterDiameter);
                g2.setPaint(Color.RED);
                g2.drawLine(ShooterLocationX + shooterDiameter/2, ShooterLocationY+shooterDiameter/2, (int)targetVectorX, (int)targetVectorY);
                g2.setPaint(Color.BLUE);
                g2.fillRect(targetX, targetY, targetWidth, targetLength);
                g2.setPaint(Color.RED);
                g2.fillOval((int)shot.newX, (int)shot.newY, 10, 10);
                g2.setPaint(Color.gray);
                g2.fillRect(600, 0, 300,300);
                for(int i = 0; i < input.length;i++){
                    if(outputs1[i] > 0.5){
                        g2.setPaint(Color.red);
                    }else{
                        g2.setPaint(Color.white);
                    }
                    g2.drawOval(625+i*35, 250, 25, 25);
                    for(int j = 0;j<hidden.length;j++){
                        g2.drawLine(640+i*35, 250, 615+j*30, 195);
                    }
                }
                for(int i = 0; i < hidden.length;i++){
                    if(outputs2[i] > 0.5){
                        g2.setPaint(Color.red);
                    }else{
                        g2.setPaint(Color.white);
                    }
                    g2.drawOval(605+i*30,  175, 20, 20);
                    for(int j = 0;j<hidden2.length;j++){
                        g2.drawLine(615+i*30, 175, 615+j*30, 145);
                    }
                }
                for(int i = 0; i < hidden2.length;i++){
                    if(outputs3[i] > 0.5){
                        g2.setPaint(Color.red);
                    }else{
                        g2.setPaint(Color.white);
                    }
                    g2.drawOval(605+i*30, 125, 20, 20);
                    for(int j = 0;j<output.length;j++){
                        g2.drawLine(615+i*30, 125, 697+j*50, 75);
                    }
                }
                for(int i = 0; i < output.length;i++){
                    if(outputs4[i] > 0.5){
                        g2.setPaint(Color.red);
                    }else{
                        g2.setPaint(Color.white);
                    }
                    g2.drawOval(685+i*50, 50, 25, 25);

                }
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
                run();

            }
        }


    public class Shooter {
        double windSpeed = 0.5;
        int distance = 0;
        double newX;
        double newY;
        double rotation;
        double temp = 0;
        public Shooter(double rotate){
            rotation = rotate*360;
            windSpeed = windSpeed*5;
        }
        public void newLocationX(double XPrev){
            newX = XPrev + Math.cos(Math.toRadians(rotation)) * 20;
        }

        public void newLocationY(double YPrev){
            newY = YPrev + Math.sin(Math.toRadians(rotation)) * 20;
                newY =newY - windSpeed;
        }

    }

    import java.util.Random;


public class Percepton {
    int numOfInputs = 100;
    float[] weights = new float[numOfInputs];
    int sum;
    double learningRate = 0.1;
    float biasWeight;
    float a;
    float out;
    Random random = new Random();
    public void setWeight(int index,float weight){
        weights[index] = weight;
    }
    public void setInputNumber(int inputs){
        numOfInputs = inputs;
        for(int i = 0; i < inputs; i++){
            weights[i] = (float)(random.nextInt(400)-200)/200;
            biasWeight = (float)(random.nextInt(400)-200)/200;
            //System.out.println(weights[i]);

        }
    }

    public void learn(int index, double error, float value){
        weights[index] += (float)learningRate * (float)error * value;
    }

    public double sigmoid(double input,double bias){
        return  (float) (1/(1+Math.exp(-input*bias)));
    }
}

0 个答案:

没有答案