使用ReLU传递函数时,XOR问题的总网络误差卡住不再下降

时间:2017-04-09 11:40:28

标签: java machine-learning neural-network sigmoid activation-function

我尝试在XOR问题上使用ReLU激活函数来查看它的性能,因为我看到很多帖子和页面都说它比sigmoid和其他激活函数更好。我用了这段代码:

/**
 * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package classifier;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.TransferFunctionType;

/**
 * This sample shows how to create, train, save and load a simple Multi Layer Perceptron for the XOR problem.
 * This sample shows basics of Neuroph API.
 * <p>
 * Training is capped at {@link #MAX_ITERATIONS} epochs so that the sample always proceeds to the
 * test/save/load steps, even when the total network error stops improving.
 *
 * @author Zoran Sevarac <sevarac@gmail.com>
 */
public class XorMultiLayerPerceptronSample implements LearningEventListener {

    /**
     * Upper bound on the number of training iterations. Without a bound, {@code learn()} keeps
     * running for as long as the error stays above the learning rule's stop threshold.
     */
    private static final int MAX_ITERATIONS = 10000;

    /**
     * Program entry point; creates the sample and runs it.
     *
     * @param args command line arguments (not used)
     */
    public static void main(String[] args) {
        new XorMultiLayerPerceptronSample().run();
    }

    /**
     * Runs this sample: builds the XOR training set, trains a 2-3-1 perceptron that uses the
     * RELU transfer function type, prints its outputs, saves it to {@code myMlPerceptron.nnet},
     * loads it back and prints the outputs of the loaded network.
     */
    public void run() {

        // create training set (logical XOR function)
        DataSet trainingSet = new DataSet(2, 1);
        trainingSet.addRow(new DataSetRow(new double[]{0, 0}, new double[]{0}));
        trainingSet.addRow(new DataSetRow(new double[]{0, 1}, new double[]{1}));
        trainingSet.addRow(new DataSetRow(new double[]{1, 0}, new double[]{1}));
        trainingSet.addRow(new DataSetRow(new double[]{1, 1}, new double[]{0}));

        // create multi layer perceptron
        MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.RELU, 2, 3, 1);

        // plain backpropagation, bounded so that a stalled error cannot keep training running forever
        BackPropagation backPropagation = new BackPropagation();
        backPropagation.setMaxIterations(MAX_ITERATIONS);
        myMlPerceptron.setLearningRule(backPropagation);

        // report the total network error while learning
        backPropagation.addListener(this);

        // learn the training set
        System.out.println("Training neural network...");
        myMlPerceptron.learn(trainingSet);

        // test perceptron
        System.out.println("Testing trained neural network");
        testNeuralNetwork(myMlPerceptron, trainingSet);

        // save trained neural network
        myMlPerceptron.save("myMlPerceptron.nnet");

        // load saved neural network
        NeuralNetwork loadedMlPerceptron = NeuralNetwork.createFromFile("myMlPerceptron.nnet");

        // test loaded neural network
        System.out.println("Testing loaded neural network");
        testNeuralNetwork(loadedMlPerceptron, trainingSet);
    }

    /**
     * Prints the network output for each row of the specified data set.
     *
     * @param neuralNet neural network to evaluate
     * @param testSet data set whose row inputs are fed to the network
     */
    public static void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {

        for (DataSetRow testSetRow : testSet.getRows()) {
            neuralNet.setInput(testSetRow.getInput());
            neuralNet.calculate();
            double[] networkOutput = neuralNet.getOutput();

            System.out.print("Input: " + Arrays.toString( testSetRow.getInput() ) );
            System.out.println(" Output: " + Arrays.toString( networkOutput) );
        }
    }

    /**
     * Prints the current iteration and total network error for every learning event except
     * the final {@code LEARNING_STOPPED} notification.
     *
     * @param event learning event; its source is cast to {@link BackPropagation}, which is the
     *              rule this sample registers itself on
     */
    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation bp = (BackPropagation) event.getSource();
        if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED) {
            System.out.println(bp.getCurrentIteration() + ". iteration : "+ bp.getTotalNetworkError());
        }
    }

}

我将RELU添加到TransferFunctionType中:

  /**
 * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.neuroph.util;

import org.neuroph.core.transfer.Gaussian;
import org.neuroph.core.transfer.Linear;
import org.neuroph.core.transfer.Log;
import org.neuroph.core.transfer.Ramp;
import org.neuroph.core.transfer.RectifiedLinear;
import org.neuroph.core.transfer.Sgn;
import org.neuroph.core.transfer.Sigmoid;
import org.neuroph.core.transfer.Sin;
import org.neuroph.core.transfer.Step;
import org.neuroph.core.transfer.Tanh;
import org.neuroph.core.transfer.Trapezoid;

/**
 * Enumerates the transfer function types. Each constant carries a display label and can be
 * mapped to its associated class from {@code org.neuroph.core.transfer}.
 */
public enum TransferFunctionType {
    LINEAR("Linear"),
    RAMP("Ramp"),
    STEP("Step"),
    SIGMOID("Sigmoid"),
    TANH("Tanh"),
    GAUSSIAN("Gaussian"),
    TRAPEZOID("Trapezoid"),
    SGN("Sgn"),
    SIN("Sin"),
    LOG("Log"),
    RELU("ReLU");

    /** Display label of this type. */
    private String typeLabel;

    /**
     * Creates a type constant with the given display label.
     *
     * @param label display label for the constant
     */
    private TransferFunctionType(String label) {
        typeLabel = label;
    }

    /**
     * Returns the display label of this type.
     *
     * @return the label passed when the constant was declared
     */
    public String getTypeLabel() {
        return this.typeLabel;
    }

    /**
     * Returns the class associated with this type.
     *
     * @return the matching class, or {@code null} if the constant has no mapping
     */
    public Class getTypeClass() {
        Class mappedClass;

        switch (this) {
            case LINEAR:
                mappedClass = Linear.class;
                break;
            case STEP:
                mappedClass = Step.class;
                break;
            case RAMP:
                mappedClass = Ramp.class;
                break;
            case SIGMOID:
                mappedClass = Sigmoid.class;
                break;
            case TANH:
                mappedClass = Tanh.class;
                break;
            case TRAPEZOID:
                mappedClass = Trapezoid.class;
                break;
            case GAUSSIAN:
                mappedClass = Gaussian.class;
                break;
            case SGN:
                mappedClass = Sgn.class;
                break;
            case SIN:
                mappedClass = Sin.class;
                break;
            case LOG:
                mappedClass = Log.class;
                break;
            case RELU:
                mappedClass = RectifiedLinear.class;
                break;
            default:
                // no mapping for this constant
                mappedClass = null;
                break;
        }

        return mappedClass;
    }

}

我正在使用 Neuroph 2.92,运行时总网络误差卡在 0.25 不再下降。我也试过 TANH,它卡在 1.25,而 SIGMOID 很容易就能达到总误差 < 0.01。这是怎么回事?还是我在某处犯了错误?谢谢。

0 个答案:

没有答案