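I have developed a simple ANN in Java that uses backpropagation and mini-batch stochastic gradient descent. It works well for simple problems such as XOR and for small datasets such as the Iris dataset.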
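For reference, the kind of training loop described above can be sketched roughly as follows. This is a minimal, hypothetical example, not the question's `NeuralNetwork2` class (the class `MiniBatchSgd` and its methods are made up for illustration). It is a one-hidden-layer sigmoid network trained on XOR with backpropagation and mini-batch SGD, using the same hyperparameter order as the question's `train` call (rate, momentum, decay, epochs, batch size). It assumes a cross-entropy cost on a single sigmoid output, so the output error is simply `a - y`. It also treats `decay` as L2 weight decay applied to weights but not biases; that is an assumption, since `decay` could equally mean learning-rate decay.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch: 1-hidden-layer sigmoid net trained by backprop + mini-batch SGD. */
public class MiniBatchSgd {
    final int in, hid;
    final double[][] w1, v1;         // hidden weights and their momentum buffers
    final double[] b1, vb1, w2, v2;  // hidden biases, output weights, and momentum buffers
    double b2, vb2;                  // output bias and its momentum buffer

    MiniBatchSgd(int in, int hid, long seed) {
        this.in = in;
        this.hid = hid;
        Random r = new Random(seed);
        w1 = new double[hid][in]; v1 = new double[hid][in];
        b1 = new double[hid]; vb1 = new double[hid];
        w2 = new double[hid]; v2 = new double[hid];
        for (int j = 0; j < hid; j++) {
            w2[j] = r.nextGaussian();
            for (int i = 0; i < in; i++) w1[j][i] = r.nextGaussian();
        }
    }

    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    /** Forward pass: fills h with hidden activations and returns the output activation. */
    double forward(double[] x, double[] h) {
        double z = b2;
        for (int j = 0; j < hid; j++) {
            double s = b1[j];
            for (int i = 0; i < in; i++) s += w1[j][i] * x[i];
            h[j] = sigmoid(s);
            z += w2[j] * h[j];
        }
        return sigmoid(z);
    }

    /** Mean cross-entropy over the data set (clamped to avoid log(0)). */
    double loss(double[][] xs, double[] ys) {
        double[] h = new double[hid];
        double sum = 0;
        for (int k = 0; k < xs.length; k++) {
            double a = Math.min(Math.max(forward(xs[k], h), 1e-12), 1 - 1e-12);
            sum -= ys[k] * Math.log(a) + (1 - ys[k]) * Math.log(1 - a);
        }
        return sum / xs.length;
    }

    /** Shuffle each epoch, average gradients over each mini-batch, then take one momentum step. */
    void train(double[][] xs, double[] ys, double rate, double momentum, double decay,
               int epochs, int batchSize, long seed) {
        Random r = new Random(seed);
        int n = xs.length;
        int[] idx = new int[n];
        for (int k = 0; k < n; k++) idx[k] = k;
        double[] h = new double[hid];
        for (int e = 0; e < epochs; e++) {
            for (int k = n - 1; k > 0; k--) { // Fisher-Yates shuffle of the sample order
                int m = r.nextInt(k + 1); int t = idx[k]; idx[k] = idx[m]; idx[m] = t;
            }
            for (int start = 0; start < n; start += batchSize) {
                int end = Math.min(start + batchSize, n);
                double[][] g1 = new double[hid][in];
                double[] gb1 = new double[hid], g2 = new double[hid];
                double gb2 = 0;
                for (int k = start; k < end; k++) {
                    double[] x = xs[idx[k]];
                    double dOut = forward(x, h) - ys[idx[k]]; // sigmoid + cross-entropy: dC/dz = a - y
                    gb2 += dOut;
                    for (int j = 0; j < hid; j++) {
                        g2[j] += dOut * h[j];
                        double dHid = dOut * w2[j] * h[j] * (1 - h[j]); // backprop through hidden sigmoid
                        gb1[j] += dHid;
                        for (int i = 0; i < in; i++) g1[j][i] += dHid * x[i];
                    }
                }
                int bs = end - start;
                // Classical momentum: v = mu*v - rate*(grad + decay*w); w += v (no decay on biases)
                for (int j = 0; j < hid; j++) {
                    for (int i = 0; i < in; i++) {
                        v1[j][i] = momentum * v1[j][i] - rate * (g1[j][i] / bs + decay * w1[j][i]);
                        w1[j][i] += v1[j][i];
                    }
                    vb1[j] = momentum * vb1[j] - rate * gb1[j] / bs; b1[j] += vb1[j];
                    v2[j] = momentum * v2[j] - rate * (g2[j] / bs + decay * w2[j]); w2[j] += v2[j];
                }
                vb2 = momentum * vb2 - rate * gb2 / bs; b2 += vb2;
            }
        }
    }

    public static void main(String[] args) {
        double[][] xs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] ys = {0, 1, 1, 0};
        MiniBatchSgd net = new MiniBatchSgd(2, 4, 42);
        System.out.printf("loss before: %.4f%n", net.loss(xs, ys));
        net.train(xs, ys, 0.5, 0.5, 0.0, 5000, 2, 7); // rate, momentum, decay, epochs, batch size, seed
        System.out.printf("loss after:  %.4f%n", net.loss(xs, ys));
        double[] h = new double[4];
        for (double[] x : xs) System.out.println(Arrays.toString(x) + " -> " + net.forward(x, h));
    }
}
```

Averaging the gradients over each batch (dividing by `bs`) keeps the effective step size roughly independent of the batch size.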
The following code, trained on the Iris dataset, gives accurate results:
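import java.io.FileReader;
import java.util.Locale;
import java.util.Scanner;

// new FileReader(...) throws FileNotFoundException, so the enclosing method
// must declare or handle it.
// Training data: 147 samples stored one per column (4 features in x, 3 one-hot labels in y)
double[][] x = new double[4][147];
double[][] y = new double[3][147];
try (Scanner s = new Scanner(new FileReader("iris1.txt")).useLocale(Locale.US)) {
    // The file holds all feature vectors first, then all label vectors
    for (int i = 0; i < x[0].length; i++) {
        for (int j = 0; j < x.length; j++) {
            x[j][i] = s.nextDouble();
        }
    }
    for (int i = 0; i < y[0].length; i++) {
        for (int j = 0; j < y.length; j++) {
            y[j][i] = s.nextDouble();
        }
    }
}

// Four test samples (one per column) and their one-hot labels
double[][] test = {{5.0, 5.7, 5.9, 4.9}, {3.3, 2.8, 3.0, 3.1}, {1.4, 4.1, 5.1, 1.6}, {0.2, 1.3, 1.8, 0.4}};
double[][] lbl = {{1, 0, 0, 1}, {0, 1, 0, 0}, {0, 0, 1, 0}};
NeuralNetwork2 NN = new NeuralNetwork2(x, y, 30, test, lbl);
NN.train(0.25, 0.0, 0.0, 100, 3); // rate, momentum, decay, epochs, batch size

// Print the network output: one row per class, one column per test sample
double[][] result = NN.a_3(test);
for (int i = 0; i < result.length; i++) {
    for (int j = 0; j < result[0].length; j++) {
        System.out.print((float) result[i][j] + " | ");
    }
    System.out.println();
}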
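Output:

Completed epoch 100 acc:1.0
0.9979585 | 0.002576745 | 6.41456E-5 | 0.9972225 |
0.0017186871 | 0.99348587 | 0.047361016 | 0.0023599463 |
2.234334E-6 | 0.0026515499 | 0.95999074 | 2.7198328E-6 |

However, on more complex datasets, such as the poker hand dataset or the MNIST handwritten-digit dataset, it only outputs random predictions. I have tried many learning rates and batch sizes without seeing any improvement.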
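The printed matrix has one row per class and one column per test sample, so each sample's predicted class is the row holding that column's largest value. A minimal sketch of computing predictions and accuracy from such a matrix, plus a check for the degenerate case where every column is the same (the network returns the same output whatever the input). The class `OutputCheck` and its methods are hypothetical helpers, not part of the question's code; the sample data is the Iris output and labels shown in the question.

```java
import java.util.Arrays;

/** Hypothetical helpers for a classes-by-samples output matrix (one column per sample). */
public class OutputCheck {
    /** Predicted class of each sample: row index of the largest value in its column. */
    static int[] argmaxPerColumn(double[][] m) {
        int[] pred = new int[m[0].length];
        for (int col = 0; col < m[0].length; col++) {
            int best = 0;
            for (int row = 1; row < m.length; row++)
                if (m[row][col] > m[best][col]) best = row;
            pred[col] = best;
        }
        return pred;
    }

    /** Fraction of samples whose predicted class matches the one-hot label column. */
    static double accuracy(double[][] out, double[][] oneHot) {
        int[] p = argmaxPerColumn(out), t = argmaxPerColumn(oneHot);
        int hits = 0;
        for (int i = 0; i < p.length; i++) if (p[i] == t[i]) hits++;
        return (double) hits / p.length;
    }

    /** True when every column equals the first one within tol, i.e. the output ignores the input. */
    static boolean allColumnsIdentical(double[][] m, double tol) {
        for (double[] row : m)
            for (int col = 1; col < row.length; col++)
                if (Math.abs(row[col] - row[0]) > tol) return false;
        return true;
    }

    public static void main(String[] args) {
        // Iris output and labels from the question
        double[][] result = {
            {0.9979585, 0.002576745, 6.41456E-5, 0.9972225},
            {0.0017186871, 0.99348587, 0.047361016, 0.0023599463},
            {2.234334E-6, 0.0026515499, 0.95999074, 2.7198328E-6}};
        double[][] lbl = {{1, 0, 0, 1}, {0, 1, 0, 0}, {0, 0, 1, 0}};
        System.out.println(Arrays.toString(argmaxPerColumn(result))); // [0, 1, 2, 0]
        System.out.println(accuracy(result, lbl));                    // 1.0
        System.out.println(allColumnsIdentical(result, 1e-6));        // false
    }
}
```

Near-identical columns indicate that the network produces the same output regardless of its input, which is a different symptom from merely low accuracy.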
For example, the following MNIST code produces a strange result in which all output columns are identical:
// Load the first 6000 MNIST training images and labels
double[][] x = MNIST_reader.ImgReader("train-images.idx3-ubyte", 6000);
double[][] y = MNIST_reader.LblReader("train-labels.idx1-ubyte", 6000);
// Note: "test" is read from the training file too, so it probably overlaps with x
double[][] test = MNIST_reader.ImgReader("train-images.idx3-ubyte", 10);
double[][] labels = MNIST_reader.LblReader("train-labels.idx1-ubyte", 10);

System.out.println("Starting training");
NeuralNetwork2 NN = new NeuralNetwork2(x, y, 5);
NN.train(3.0, 0.0, 0.0, 10, 10); // rate, momentum, decay, epochs, batch size
System.out.println("Done!");

// Called on the network instance, as in the Iris example
double[][] result = NN.a_3(test);
for (int i = 0; i < result.length; i++) {
    for (int j = 0; j < result[0].length; j++) {
        System.out.print(result[i][j] + " ");
    }
    System.out.println();
}
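`MNIST_reader` is not shown, so what it returns is worth confirming. A minimal sketch of reading the IDX files into the same features-by-samples layout used above, assuming the standard IDX header (magic number 2051 for image files and 2049 for label files, followed by big-endian 32-bit counts and dimensions) and one unsigned byte per pixel or label. The class `IdxSketch` and its methods are hypothetical, not the question's `MNIST_reader`. Scaling pixels from 0..255 to 0..1 is a common convention in this sketch; whether `MNIST_reader` does the same is not shown in the question.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;

/** Hypothetical IDX (MNIST) reader producing features-by-samples matrices. */
public class IdxSketch {
    /** Reads up to n images; returns [rows*cols][n] with pixels scaled from 0..255 to 0..1. */
    static double[][] readImages(InputStream in, int n) throws IOException {
        DataInputStream d = new DataInputStream(in); // readInt is big-endian, as IDX requires
        if (d.readInt() != 2051) throw new IOException("not an IDX image file");
        int count = Math.min(n, d.readInt());
        int pixels = d.readInt() * d.readInt();      // rows * cols
        double[][] x = new double[pixels][count];
        for (int s = 0; s < count; s++)              // images are stored one after another
            for (int p = 0; p < pixels; p++)         // row-major pixels, one unsigned byte each
                x[p][s] = d.readUnsignedByte() / 255.0;
        return x;
    }

    /** Reads up to n labels (digits 0..9); returns a one-hot [10][n] matrix. */
    static double[][] readLabels(InputStream in, int n) throws IOException {
        DataInputStream d = new DataInputStream(in);
        if (d.readInt() != 2049) throw new IOException("not an IDX label file");
        int count = Math.min(n, d.readInt());
        double[][] y = new double[10][count];
        for (int s = 0; s < count; s++) y[d.readUnsignedByte()][s] = 1.0;
        return y;
    }

    public static void main(String[] args) throws IOException {
        // Tiny in-memory IDX image file: 2 images of 2x2 pixels
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(buf);
        out.writeInt(2051); out.writeInt(2); out.writeInt(2); out.writeInt(2);
        out.write(new byte[]{0, (byte) 255, 51, 102, 0, 0, 0, (byte) 255});
        double[][] x = readImages(new ByteArrayInputStream(buf.toByteArray()), 2);
        System.out.println(x.length + " x " + x[0].length); // 4 x 2
        System.out.println(x[1][0]);                        // 1.0
    }
}
```

With the real files, the same methods could be called on, for example, `new BufferedInputStream(new FileInputStream("train-images.idx3-ubyte"))`.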
Any ideas what is going wrong here?