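I tried the ReLU activation function on the XOR problem to see how it performs, because many posts and pages say it works better than sigmoid and the other functions. I used this code: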
/**
* Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package classifier;
import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.TransferFunctionType;
/**
* This sample shows how to create, train, save and load a simple Multi Layer Perceptron for the XOR problem.
* This sample shows basics of Neuroph API.
* @author Zoran Sevarac <sevarac@gmail.com>
*/
public class XorMultiLayerPerceptronSample implements LearningEventListener {
public static void main(String[] args) {
new XorMultiLayerPerceptronSample().run();
}
/**
* Runs this sample
*/
public void run() {
// create training set (logical XOR function)
DataSet trainingSet = new DataSet(2, 1);
trainingSet.addRow(new DataSetRow(new double[]{0, 0}, new double[]{0}));
trainingSet.addRow(new DataSetRow(new double[]{0, 1}, new double[]{1}));
trainingSet.addRow(new DataSetRow(new double[]{1, 0}, new double[]{1}));
trainingSet.addRow(new DataSetRow(new double[]{1, 1}, new double[]{0}));
// create multi layer perceptron
MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.RELU, 2, 3, 1);
myMlPerceptron.setLearningRule(new BackPropagation());
// optionally set batch mode if using MomentumBackpropagation (false = update weights after every row)
// if( myMlPerceptron.getLearningRule() instanceof MomentumBackpropagation )
// ((MomentumBackpropagation)myMlPerceptron.getLearningRule()).setBatchMode(false);
LearningRule learningRule = myMlPerceptron.getLearningRule();
learningRule.addListener(this);
// learn the training set
System.out.println("Training neural network...");
myMlPerceptron.learn(trainingSet);
// test perceptron
System.out.println("Testing trained neural network");
testNeuralNetwork(myMlPerceptron, trainingSet);
// save trained neural network
myMlPerceptron.save("myMlPerceptron.nnet");
// load saved neural network
NeuralNetwork loadedMlPerceptron = NeuralNetwork.createFromFile("myMlPerceptron.nnet");
// test loaded neural network
System.out.println("Testing loaded neural network");
testNeuralNetwork(loadedMlPerceptron, trainingSet);
}
/**
* Prints network output for each element of the specified test set.
* @param neuralNet neural network
* @param testSet test set
*/
public static void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {
for(DataSetRow testSetRow : testSet.getRows()) {
neuralNet.setInput(testSetRow.getInput());
neuralNet.calculate();
double[] networkOutput = neuralNet.getOutput();
System.out.print("Input: " + Arrays.toString( testSetRow.getInput() ) );
System.out.println(" Output: " + Arrays.toString( networkOutput) );
}
}
@Override
public void handleLearningEvent(LearningEvent event) {
BackPropagation bp = (BackPropagation)event.getSource();
if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED)
System.out.println(bp.getCurrentIteration() + ". iteration : "+ bp.getTotalNetworkError());
}
}
I added RELU to TransferFunctionType:
/**
* Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neuroph.util;
import org.neuroph.core.transfer.Gaussian;
import org.neuroph.core.transfer.Linear;
import org.neuroph.core.transfer.Log;
import org.neuroph.core.transfer.Ramp;
import org.neuroph.core.transfer.RectifiedLinear;
import org.neuroph.core.transfer.Sgn;
import org.neuroph.core.transfer.Sigmoid;
import org.neuroph.core.transfer.Sin;
import org.neuroph.core.transfer.Step;
import org.neuroph.core.transfer.Tanh;
import org.neuroph.core.transfer.Trapezoid;
/**
* Contains transfer functions types and labels.
*/
public enum TransferFunctionType {
LINEAR("Linear"),
RAMP("Ramp"),
STEP("Step"),
SIGMOID("Sigmoid"),
TANH("Tanh"),
GAUSSIAN("Gaussian"),
TRAPEZOID("Trapezoid"),
SGN("Sgn"),
SIN("Sin"),
LOG("Log"),
RELU("ReLU");
private String typeLabel;
private TransferFunctionType(String typeLabel) {
this.typeLabel = typeLabel;
}
public String getTypeLabel() {
return typeLabel;
}
public Class getTypeClass() {
switch (this) {
case LINEAR:
return Linear.class;
case STEP:
return Step.class;
case RAMP:
return Ramp.class;
case SIGMOID:
return Sigmoid.class;
case TANH:
return Tanh.class;
case TRAPEZOID:
return Trapezoid.class;
case GAUSSIAN:
return Gaussian.class;
case SGN:
return Sgn.class;
case SIN:
return Sin.class;
case LOG:
return Log.class;
case RELU:
return RectifiedLinear.class;
} // switch
return null;
}
}
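The RELU entry above maps to Neuroph's RectifiedLinear class. To illustrate one way ReLU can behave very differently from sigmoid under plain backpropagation, here is a standalone sketch (not Neuroph's RectifiedLinear source). In backpropagation, the error passed through a neuron is multiplied by the derivative of its transfer function. For ReLU that derivative is 0 whenever the neuron's weighted input is negative, so a neuron whose net input is negative for every training row receives no weight updates at all (often called a "dead" ReLU). Sigmoid's derivative is small but never exactly 0. The weights and bias in `main` are hypothetical values chosen only to make the net input negative for all four XOR rows. Whether this affects the hidden layer, the output layer, or both depends on how the network applies the transfer function type.

```java
import java.util.Arrays;

// Standalone illustration of ReLU vs. sigmoid derivatives.
// NOT Neuroph's implementation; names here are hypothetical.
public class ReluSketch {

    // ReLU: f(x) = max(0, x)
    static double relu(double x) {
        return Math.max(0.0, x);
    }

    // ReLU derivative: 1 for x > 0, 0 otherwise (value at x = 0 is a convention)
    static double reluDerivative(double x) {
        return x > 0 ? 1.0 : 0.0;
    }

    // Sigmoid: f(x) = 1 / (1 + e^-x), output in (0, 1)
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Sigmoid derivative: f(x) * (1 - f(x)), never exactly 0 for finite x
    static double sigmoidDerivative(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    public static void main(String[] args) {
        // Hypothetical weights and bias for one neuron with two inputs,
        // chosen so the weighted sum is negative for every XOR input.
        double w1 = -0.4, w2 = -0.3, bias = -0.1;
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        for (double[] in : inputs) {
            double net = w1 * in[0] + w2 * in[1] + bias;
            // relu' is 0 for every row, so backpropagation would not change w1, w2, bias
            System.out.printf("input %s net=%.2f relu'=%.1f sigmoid'=%.4f%n",
                    Arrays.toString(in), net,
                    reluDerivative(net), sigmoidDerivative(net));
        }
    }
}
```

Running `main` prints a ReLU derivative of 0.0 for all four rows, while the sigmoid derivative stays positive.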
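I'm using Neuroph 2.92, and when I run it, training gets stuck at a total network error of 0.25. I also tried TANH and it gets stuck at 1.25, while SIGMOID easily reaches a total error < 0.01. What is wrong with them? Or did I make a mistake somewhere? Thanks.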
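A plateau at exactly 0.25 is easier to interpret with a small calculation. Assuming plain mean squared error (sum of squared errors divided by the number of rows; Neuroph's own normalization of "total network error" may differ), a network whose output is the same constant c for every XOR row scores (c² + (1 - c)²) / 2, which is minimized at c = 0.5 with an error of exactly 0.25. So, under that assumption, 0.25 is the best an input-independent output can do, which is consistent with (but does not prove) a network that has stopped using its inputs. The class and method names below are hypothetical.

```java
import java.util.Arrays;

// Sketch: MSE on the four XOR rows for a network whose output ignores its input.
// Assumes plain MSE (sum of squared errors / n), which may not match Neuroph exactly.
public class ConstantOutputError {

    static final double[] XOR_TARGETS = {0, 1, 1, 0};

    // Plain mean squared error over all rows
    static double meanSquaredError(double[] outputs, double[] targets) {
        double sum = 0.0;
        for (int i = 0; i < targets.length; i++) {
            double e = targets[i] - outputs[i];
            sum += e * e;
        }
        return sum / targets.length;
    }

    // Error when the network outputs the same value c for every row
    static double constantOutputError(double c) {
        double[] outputs = new double[XOR_TARGETS.length];
        Arrays.fill(outputs, c);
        return meanSquaredError(outputs, XOR_TARGETS);
    }

    public static void main(String[] args) {
        for (double c : new double[]{0.0, 0.25, 0.5, 0.75, 1.0}) {
            System.out.printf("constant output %.2f -> MSE %.4f%n", c, constantOutputError(c));
        }
    }
}
```

For c = 0.0, 0.25, 0.5, 0.75 and 1.0 this gives 0.5, 0.3125, 0.25, 0.3125 and 0.5 respectively.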