我一直在尝试制作一个能瞄准并击中移动目标的神经网络。它的输入包括：射手到目标在两个坐标轴上的距离、射手的旋转角度以及风速。
射手每 5 秒会被重置到屏幕上的另一个位置。
我尝试使用反向传播算法：根据子弹落点与目标之间的距离计算误差；如果射手在 5 秒的时间窗口内一次都没有开火，也会额外传播一个误差。
但我的网络似乎没有正确地学习，希望有人能给我指出正确的方向。
这是我的代码:
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.util.Random;
import java.util.stream.DoubleStream;
import javax.swing.JFrame;
/**
 * Swing window that simulates a shooter trying to hit a vertically bouncing target, with the
 * shooter's actions chosen by a small feed-forward network: 4 input neurons, two hidden layers
 * of 8 neurons and 3 output neurons (fire, increase rotation, decrease rotation).
 *
 * <p>The simulation advances one tick every {@code TICK_MILLIS} ms on the event dispatch thread.
 * Every {@code EPISODE_MILLIS} ms the shooter is moved to a random position, the wind is
 * re-rolled and the error accumulated during the episode is passed to {@code Percepton.learn}.
 */
public class ClassMain extends JFrame {
    /** Delay between simulation ticks, in milliseconds. */
    private static final int TICK_MILLIS = 100;
    /** Length of one episode (one shooter position and wind value), in milliseconds. */
    private static final long EPISODE_MILLIS = 5000;

    int ShooterLocationX = 100;
    int ShooterLocationY = 150;
    int shooterDiameter = 35;
    int targetX = 525;
    int targetY = 100;
    int targetWidth = 15;
    int targetLength = 50;
    /** Error accumulated over the current episode; hits subtract from it, misses add to it. */
    double error;
    double targetVectorX = ShooterLocationX + 100;
    double targetVectorY = ShooterLocationY + shooterDiameter / 2;
    /** The target moves towards larger y when true, smaller y when false. */
    boolean direction = false;
    /** True while a projectile is in flight. */
    boolean ShotFired = false;
    boolean rotateDirection;
    /** True when the episode timer must be restarted on the next tick. */
    boolean start = true;
    /** True once a shot has been fired during the current episode. */
    boolean shotbeenFires;
    /** Shooter heading in degrees (it is passed to Math.toRadians when drawing and firing). */
    double ShooterRotation = 0;
    /** Activations of the input, hidden, hidden2 and output layers respectively. */
    double[] outputs1 = new double[4];
    double[] outputs2 = new double[8];
    double[] outputs3 = new double[8];
    double[] outputs4 = new double[3];
    Percepton[] input = new Percepton[4];
    Percepton[] hidden = new Percepton[8];
    Percepton[] hidden2 = new Percepton[8];
    Percepton[] output = new Percepton[3];
    Shooter shot = new Shooter(ShooterRotation);
    Random random = new Random();
    long StartTime = 0;
    long CurrTime = 0;

    public static void main(String[] args) {
        // Swing components must be created and used on the event dispatch thread.
        javax.swing.SwingUtilities.invokeLater(() -> new ClassMain().setup());
    }

    /**
     * Builds the network, shows the window and starts the simulation timer.
     */
    public void setup() {
        // Build the network before the frame becomes visible: painting reads the layer outputs,
        // and the first paint can arrive as soon as setVisible(true) is called.
        for (int i = 0; i < input.length; i++) {
            input[i] = new Percepton();
            input[i].setInputNumber(1);
        }
        for (int i = 0; i < hidden.length; i++) {
            hidden[i] = new Percepton();
            hidden[i].setInputNumber(input.length);
        }
        for (int i = 0; i < hidden2.length; i++) {
            hidden2[i] = new Percepton();
            // hidden2 is fed by the outputs of `hidden`, not by the input layer; with
            // input.length only 4 of its 8 weights were ever initialised.
            hidden2[i].setInputNumber(hidden.length);
        }
        for (int i = 0; i < output.length; i++) {
            output[i] = new Percepton();
            output[i].setInputNumber(hidden2.length);
        }
        // Start the shot at the muzzle; at (0,0) it registered as a MISS on the first tick.
        resetShotPosition();

        setSize(900, 300);
        setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        setResizable(false);
        setVisible(true);
        createBufferStrategy(2);

        // Drive the simulation from a Swing timer instead of from paint(): paint() must not
        // sleep or advance state, and the timer callback runs on the event dispatch thread.
        new javax.swing.Timer(TICK_MILLIS, e -> run()).start();
    }

    /**
     * Advances the simulation by one tick: ends the episode if it has run long enough, runs the
     * network, applies its actions, moves the shot and target, then schedules a repaint.
     */
    public void run() {
        if (start) {
            StartTime = System.currentTimeMillis();
            start = false;
        }
        CurrTime = System.currentTimeMillis();
        if (CurrTime - StartTime > EPISODE_MILLIS) {
            endEpisode();
        }
        forwardPass();
        if (!ShotFired) {
            act();
        }
        stepWorld();
        repaint();
    }

    /** Puts the projectile back at the shooter's centre, offset by half the 10px shot size. */
    private void resetShotPosition() {
        shot.newX = ShooterLocationX + (shooterDiameter / 2) - 5;
        shot.newY = ShooterLocationY + (shooterDiameter / 2) - 5;
    }

    /** Relocates the shooter, re-rolls the wind and applies the episode's accumulated error. */
    private void endEpisode() {
        ShooterLocationX = random.nextInt(400);
        ShooterLocationY = random.nextInt(275);
        resetShotPosition();
        shot.windSpeed = (float) (random.nextInt(200) - 100) / 100;
        start = true;
        if (!shotbeenFires) {
            // Penalise an episode in which the shooter never fired.
            error += 15;
        }
        // NOTE(review): this is not backpropagation. Each neuron nudges a single weight (index n)
        // by learningRate * error * thatWeight, independently of the neuron's inputs and of the
        // layers after it; for neurons with fewer inputs than n (the input neurons and `hidden`)
        // the updated weight is never read by the forward pass. A gradient step needs each
        // neuron's inputs and its back-propagated delta.
        learnLayer(output);
        learnLayer(hidden2);
        learnLayer(hidden);
        learnLayer(input);
        error = 0;
        shotbeenFires = false;
    }

    /** Applies the current episode error to weight n of every neuron n in {@code layer}. */
    private void learnLayer(Percepton[] layer) {
        for (int n = 0; n < layer.length; n++) {
            layer[n].learn(n, error, layer[n].weights[n]);
        }
    }

    /** Runs the network on the current observation, filling outputs1..outputs4. */
    private void forwardPass() {
        // Distances are divided as doubles: integer division truncated them to 0 or 2.
        double dx = Math.abs(ShooterLocationX - targetX) / 300.0 * 2;
        double dy = Math.abs(ShooterLocationY - targetY) / 150.0 * 2;
        // Rotation is in degrees; scale it to turns so it does not saturate the sigmoid.
        double[] observation = {dx, dy, ShooterRotation / 360.0, shot.windSpeed};
        for (int i = 0; i < input.length; i++) {
            // Each input neuron has a single weight.
            outputs1[i] = input[i].sigmoid(observation[i] * input[i].weights[0], input[i].biasWeight);
        }
        forwardLayer(hidden, outputs1, outputs2);
        forwardLayer(hidden2, outputs2, outputs3);
        // The output layer is fed by hidden2 through its own weights (previously it used
        // outputs3[n] times `hidden`'s weights).
        forwardLayer(output, outputs3, outputs4);
    }

    /** Computes sigmoid(weights . inputs, bias) for every neuron of {@code layer} into {@code outputs}. */
    private static void forwardLayer(Percepton[] layer, double[] inputs, double[] outputs) {
        for (int n = 0; n < layer.length; n++) {
            double sum = 0;
            for (int i = 0; i < inputs.length; i++) {
                sum += inputs[i] * layer[n].weights[i];
            }
            outputs[n] = layer[n].sigmoid(sum, layer[n].biasWeight);
        }
    }

    /** Applies the output layer: fire, and/or rotate by the corresponding hidden2 activation. */
    private void act() {
        if (outputs4[0] > 0.5) {
            ShotFired = true;
            shot.rotation = ShooterRotation;
        }
        if (outputs4[1] > 0.5) {
            ShooterRotation += outputs3[1];
        }
        if (outputs4[2] > 0.5) {
            ShooterRotation -= outputs3[2];
        }
        // Wrap back into the 0-360 degree range. The old code turned >360 into rotation+1 and
        // <0 into rotation-1, moving the angle further out of range.
        ShooterRotation %= 360;
        if (ShooterRotation < 0) {
            ShooterRotation += 360;
        }
    }

    /** Moves the projectile and the target by one tick, then scores a hit or a miss. */
    private void stepWorld() {
        if (ShotFired) {
            shotbeenFires = true;
            shot.newLocationX(shot.newX);
            shot.newLocationY(shot.newY);
        }
        // Reverse the target when it leaves the band between y=20 and y=300.
        if (targetY < 20 || targetY + targetLength > 300) {
            direction = !direction;
        }
        targetY += direction ? 8 : -8;

        boolean hit = shot.newX + 15 >= targetX && shot.newX + 15 < targetX + targetWidth
                && shot.newY + 15 > targetY && shot.newY < targetY + targetLength;
        if (hit) {
            System.out.println("Target Hit");
            error -= 10;
            resetShotPosition();
            ShotFired = false;
        } else if (shot.newX > 600 || shot.newX < 10 || shot.newY < 10 || shot.newY > 290) {
            System.out.println("MISS");
            // Use the magnitude of the vertical miss: a signed offset made misses with
            // newY < targetY reduce the error.
            error += Math.abs(shot.newY - targetY);
            resetShotPosition();
            ShotFired = false;
        }
    }

    /**
     * Draws the shooter, its aim line, the target, the projectile and a diagram of the network;
     * a neuron is outlined red when its last activation exceeded 0.5, white otherwise.
     */
    @Override
    public void paint(Graphics g) {
        double aim = Math.toRadians(ShooterRotation);
        int centerX = ShooterLocationX + shooterDiameter / 2;
        int centerY = ShooterLocationY + shooterDiameter / 2;
        targetVectorX = centerX + Math.cos(aim) * 100;
        targetVectorY = centerY + Math.sin(aim) * 100;

        super.paint(g);
        Graphics2D g2 = (Graphics2D) g;
        g2.fillOval(ShooterLocationX, ShooterLocationY, shooterDiameter, shooterDiameter);
        g2.setPaint(Color.RED);
        g2.drawLine(centerX, centerY, (int) targetVectorX, (int) targetVectorY);
        g2.setPaint(Color.BLUE);
        g2.fillRect(targetX, targetY, targetWidth, targetLength);
        g2.setPaint(Color.RED);
        g2.fillOval((int) shot.newX, (int) shot.newY, 10, 10);

        // Network diagram panel on the right-hand side of the window.
        g2.setPaint(Color.gray);
        g2.fillRect(600, 0, 300, 300);
        for (int i = 0; i < input.length; i++) {
            g2.setPaint(activationColor(outputs1[i]));
            g2.drawOval(625 + i * 35, 250, 25, 25);
            for (int j = 0; j < hidden.length; j++) {
                g2.drawLine(640 + i * 35, 250, 615 + j * 30, 195);
            }
        }
        for (int i = 0; i < hidden.length; i++) {
            g2.setPaint(activationColor(outputs2[i]));
            g2.drawOval(605 + i * 30, 175, 20, 20);
            for (int j = 0; j < hidden2.length; j++) {
                g2.drawLine(615 + i * 30, 175, 615 + j * 30, 145);
            }
        }
        for (int i = 0; i < hidden2.length; i++) {
            g2.setPaint(activationColor(outputs3[i]));
            g2.drawOval(605 + i * 30, 125, 20, 20);
            for (int j = 0; j < output.length; j++) {
                g2.drawLine(615 + i * 30, 125, 697 + j * 50, 75);
            }
        }
        for (int i = 0; i < output.length; i++) {
            g2.setPaint(activationColor(outputs4[i]));
            g2.drawOval(685 + i * 50, 50, 25, 25);
        }
    }

    /** Red for an activation above 0.5, white otherwise. */
    private static Color activationColor(double activation) {
        return activation > 0.5 ? Color.red : Color.white;
    }
}
/**
 * State of the shooter's projectile: its current position, its heading and the wind acting on it.
 *
 * <p>{@link #rotation} is in degrees (it is converted with {@code Math.toRadians}). Each call to
 * {@link #newLocationX} / {@link #newLocationY} advances the shot 20 units along that heading;
 * the Y step additionally subtracts {@link #windSpeed}.
 */
public class Shooter {
    /** Subtracted from Y on every step; the constructor multiplies this default by 5. */
    double windSpeed = 0.5;
    /** Not read within this class. */
    int distance = 0;
    /** Current X coordinate of the shot. */
    double newX;
    /** Current Y coordinate of the shot. */
    double newY;
    /** Heading in degrees. */
    double rotation;
    /** Not read within this class. */
    double temp = 0;

    /** Distance travelled along the heading per step. */
    private static final int STEP_LENGTH = 20;

    /**
     * Creates a shot whose heading is {@code rotate} full turns, i.e. {@code rotate * 360} degrees.
     *
     * @param rotate initial heading as a fraction of a full turn
     */
    public Shooter(double rotate) {
        this.windSpeed *= 5;
        this.rotation = 360 * rotate;
    }

    /**
     * Sets {@link #newX} to one step along the heading from {@code xPrev}.
     *
     * @param xPrev X coordinate to step from
     */
    public void newLocationX(double xPrev) {
        double radians = Math.toRadians(rotation);
        newX = xPrev + STEP_LENGTH * Math.cos(radians);
    }

    /**
     * Sets {@link #newY} to one step along the heading from {@code yPrev}, minus the wind drift.
     *
     * @param yPrev Y coordinate to step from
     */
    public void newLocationY(double yPrev) {
        double radians = Math.toRadians(rotation);
        double stepped = yPrev + STEP_LENGTH * Math.sin(radians);
        newY = stepped - windSpeed;
    }
}
import java.util.Random;
/**
 * A single neuron: a weight vector, a bias weight and a logistic (sigmoid) activation.
 */
public class Percepton {
    /** Number of inputs declared by the last {@link #setInputNumber} call (100 until then). */
    int numOfInputs = 100;
    /** Weight storage; always at least {@link #numOfInputs} long. */
    float[] weights = new float[numOfInputs];
    int sum;
    /** Step size used by {@link #learn}. */
    double learningRate = 0.1;
    float biasWeight;
    float a;
    float out;
    Random random = new Random();

    /**
     * Overwrites one weight.
     *
     * @param index  weight slot to set
     * @param weight new value
     */
    public void setWeight(int index, float weight) {
        weights[index] = weight;
    }

    /**
     * Declares how many inputs this neuron has. It randomises the first {@code inputs} weights
     * and the bias to multiples of 1/200 in [-1, 1).
     *
     * @param inputs number of inputs feeding this neuron
     */
    public void setInputNumber(int inputs) {
        numOfInputs = inputs;
        // Grow (never shrink) the storage, so any index that was valid before stays valid and
        // more than 100 inputs no longer overflow the array.
        if (inputs > weights.length) {
            weights = new float[inputs];
        }
        for (int i = 0; i < inputs; i++) {
            weights[i] = randomUnit();
        }
        // Drawn once; it was previously re-drawn on every loop iteration.
        biasWeight = randomUnit();
    }

    /** A random multiple of 1/200 in [-1, 1). */
    private float randomUnit() {
        return (float) (random.nextInt(400) - 200) / 200;
    }

    /**
     * Updates one weight: {@code weights[index] += learningRate * error * value}.
     * For a delta-rule / gradient step, {@code value} should be the input that fed
     * {@code weights[index]} (NOTE(review): confirm what callers pass).
     *
     * @param index weight slot to update
     * @param error error signal
     * @param value multiplier for the update
     */
    public void learn(int index, double error, float value) {
        weights[index] += (float) learningRate * (float) error * value;
    }

    /**
     * Logistic activation of {@code input} shifted by {@code bias}:
     * {@code 1 / (1 + e^-(input + bias))}.
     * The bias is added; multiplying by it made a near-zero bias pin the output at 0.5 whatever
     * the input was.
     *
     * @param input weighted sum of the neuron's inputs
     * @param bias  bias term
     * @return activation in (0, 1)
     */
    public double sigmoid(double input, double bias) {
        return 1 / (1 + Math.exp(-(input + bias)));
    }
}