最近我一直在刷机器学习,因此决定使用反向传播算法在Java中实现基本的神经网络。我已经完成了数学并检查了其他各种教程,但仍然遇到问题。对这篇文章的大小道歉。
在详细介绍算法之前,我先让您了解我一直在测试的问题。
测试问题1:
具有线性激活的单个输出神经元,学习函数x/2 + 2
的回归。这很好用,但还没有真正使用反向传播。
算法有效,收敛到接近零的误差(没有pic,因为我不能发布超过2个链接)。
测试问题2:
我的下一个测试是学习XOR问题。为此,我尝试了一个简单的网络,其中包含2个输入节点,2个隐藏节点和2个输出节点(输入节点仅提供输入且未经过训练)。
Algorithm always gets stuck on an average of 0.5 error
我运行算法的时代数并不重要,所有错误似乎都会收敛到这一点,并且网络性能很差。
实施
为了实现该算法,我将节点表示为对象,并且还具有表示激活函数的对象。
public class LogisticActivationFunction implements ActivationFunction {
@Override
public double apply(double in) {
return 1.0 / (1.0 + Math.exp(-in));
}
@Override
public double applyDerivative(double in) {
double sig = apply(in);
return sig * (1.0 - sig);
}
}
首先,前馈流程如下运行:
public List<Double> evaluate(List<Double> inputs, boolean training) {
// Set the weights in the first layer.
setInputWeights(inputs);
// Iterate through non-input layers one by one and evaluate.
NodeLayer previousLayer = layers.get(0);
for (int layerIndex = 1; layerIndex < layers.size(); layerIndex++) {
NodeLayer layer = layers.get(layerIndex);
for (int nodeIndex = 0; nodeIndex < layer.size(); nodeIndex++) {
Node node = layer.get(nodeIndex);
evaluateNode(node, previousLayer, training);
}
previousLayer = layer;
}
return getOutputWeights();
}
private void evaluateNode(Node node, NodeLayer previousLayer, boolean training) {
double sum = node.getBias();
// Create sum from all connected nodes.
for (int link : node.links()) {
if (training) {
previousLayer.get(link).registerDownstreamNode(node.getId());
}
sum += node.getUpstreamLinkStrength(link) * previousLayer.get(link).getOutput();
}
// apply the activation function.
double activation = node.getActivation().apply(sum);
node.setHiddenNode(sum, activation);
}
接下来,错误值通过网络向后传播:
protected void backPropogate(List<Double> correct) {
//float error = norm(correct, getOutputWeights());
// Final layer error.
NodeLayer outputLayer = getOutputLayer();
List<Double> output = getOutputWeights();
for (int i = 0; i < outputLayer.size(); i++) {
// Calculate error on the ith output.
double error = correct.get(i) - output.get(i);
System.out.println("error " + i + " = " + error + " = " + correct.get(i) + " - " + output.get(i));
// Set the delta to the error in dimension i multiplied by the activation derivative of the input.
Node node = outputLayer.get(i);
node.setDelta(error * node.getActivation().applyDerivative(node.getInput()));
}
NodeLayer layer = outputLayer.getUpstream(this);
while (layer != getInputLayer()) {
for (Node node : layer) {
double sum = 0;
for (Node downstream : node.downstreamNodes(this, layer)) {
sum += downstream.getDelta() * downstream.getUpstreamLinkStrength(node.getId());
}
node.setDelta(sum * node.getActivation().applyDerivative(node.getInput()));
}
layer = layer.getUpstream(this);
}
}
最后,使用梯度下降更新权重。注意,我使用了负面错误,所以这可以通过添加delta *学习率*输出来实现。
private void updateParameters(double learningRate) {
for (NodeLayer layer : this) {
if (layer == getInputLayer()) {
continue;
}
for (Node node : layer) {
double oldBias = node.getBias();
node.offsetBias(node.getDelta() * learningRate);
for (Node upstream : node.upstreamNodes(this, layer)) {
double oldW = node.getUpstreamLinkStrength(upstream);
node.offsetWeight(upstream.getId(), learningRate * node.getDelta() * upstream.getOutput());
}
}
}
}
为了将这些结合起来,我使用火车方法:
public void trainExample(List<Double> inputs, List<Double> correct, double learningRate) {
System.out.println("training example... " + Data.toString(inputs) + " -> " + Data.toString(correct));
evaluate(inputs, true);
backPropogate(correct);
updateParameters(learningRate);
}
为了训练集,我使用以下逻辑:
public List<Double> train(NodeNetwork network, List<List<Double>> trainingInput, List<List<Double>> trainingLabels, double learningRate, int epochs, boolean verbose) {
List<Double> errorLog = new ArrayList<>();
for (int i = 0; i < epochs; i++) {
for (int j = 0; j < trainingInput.size(); j++) {
int example = random.nextInt(trainingInput.size());
network.trainExample(trainingInput.get(example), trainingLabels.get(example), learningRate);
}
if (verbose) {
double error = network.checkErrorSet(trainingInput, trainingLabels);
errorLog.add(error);
System.out.println(i + " " + error);
}
}
return errorLog;
}
有没有人对如何让这个工作有任何想法?我在最后一天做了各种检查,似乎没有接近答案。
代码可以在我的github(sami016)上查看,由于URL限制我无法链接。
如果有人能指出我正确的方向,我真的很感激。谢谢你的帮助!