Code using Neuroph gives me a VectorSizeMismatchException

Time: 2019-06-05 08:06:05

Tags: java artificial-intelligence neuroph

I know how to code in Java, but I am new to Neuroph, and the following code throws a VectorSizeMismatchException.

Main: https://pastebin.com/dntWRMZN

public static void main(String[] args) {
    AiManager.trainNeuralNetwork(AiManager.initilizeNetwork());
}

Manager: https://pastebin.com/csWsiVvt

import org.neuroph.core.Layer;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.Neuron;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.Perceptron;
import org.neuroph.util.ConnectionFactory;
import org.neuroph.util.NeuralNetworkType;

public class AiManager {

  public static NeuralNetwork<?> initilizeNetwork() {
    Layer inputLayer = new Layer();
    inputLayer.addNeuron(new Neuron());
    inputLayer.addNeuron(new Neuron());

    Layer hiddenLayerOne = new Layer();
    hiddenLayerOne.addNeuron(new Neuron());
    hiddenLayerOne.addNeuron(new Neuron());
    hiddenLayerOne.addNeuron(new Neuron());
    hiddenLayerOne.addNeuron(new Neuron());

    Layer hiddenLayerTwo = new Layer();
    hiddenLayerTwo.addNeuron(new Neuron());
    hiddenLayerTwo.addNeuron(new Neuron());
    hiddenLayerTwo.addNeuron(new Neuron());
    hiddenLayerTwo.addNeuron(new Neuron());

    Layer outputLayer = new Layer();
    outputLayer.addNeuron(new Neuron());

    NeuralNetwork<?> ann = new Perceptron(2, 1);

    ann.setInputNeurons(inputLayer.getNeurons());
    ann.setOutputNeurons(outputLayer.getNeurons());

    ann.addLayer(0, inputLayer);
    ann.addLayer(1, hiddenLayerOne);

    ConnectionFactory.fullConnect(ann.getLayerAt(0), ann.getLayerAt(1));

    ann.addLayer(2, hiddenLayerTwo);

    ConnectionFactory.fullConnect(ann.getLayerAt(1), ann.getLayerAt(2));

    ann.addLayer(3, outputLayer);

    ConnectionFactory.fullConnect(ann.getLayerAt(2), ann.getLayerAt(3));

    ConnectionFactory.fullConnect(ann.getLayerAt(0), ann.getLayerAt(ann.getLayersCount()-1), false);

    ann.setNetworkType(NeuralNetworkType.MULTI_LAYER_PERCEPTRON);

    ann.setInputNeurons(inputLayer.getNeurons());
    ann.setOutputNeurons(outputLayer.getNeurons());

    return ann;
  }

  public static NeuralNetwork<?> trainNeuralNetwork(NeuralNetwork<?> ann) {
    int inputSize = 2;
    int outputSize = 1;
    DataSet ds = new DataSet(inputSize, outputSize);

    DataSetRow rOne = new DataSetRow(new double[] { 0, 1 }, new double[] { 1 });

    ds.addRow(rOne);

    DataSetRow rTwo = new DataSetRow(new double[] { 1, 1 }, new double[] { 0 });

    ds.addRow(rTwo);

    DataSetRow rThree = new DataSetRow(new double[] { 0, 0 }, new double[] { 0 });

    ds.addRow(rThree);

    DataSetRow rFour = new DataSetRow(new double[] { 1, 0 }, new double[] { 1 });

    ds.addRow(rFour);

    ann.learn(ds);

    return ann;
  }
}

1 answer:

Answer (score: 0)

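This error comes from the way the Neuroph API is implemented. You call ann.setInputNeurons(inputLayer.getNeurons()); and ann.setOutputNeurons(outputLayer.getNeurons()); twice each. Print ann.getInputNeurons().size() before and after those calls. You will see that each time you "set" the input neurons, the new neurons are added to the ones already registered instead of replacing them. The network then has more input neurons than the two values in each DataSetRow, and that mismatch is what the VectorSizeMismatchException reports.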
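The following minimal, stdlib-only Java model makes the accumulation concrete. AppendingNetwork, inputCount, and the size check are hypothetical stand-ins, not Neuroph classes. The setter copies the append-only loop from the NeuralNetwork source quoted in this answer. The check assumes the network compares the input vector length with the number of registered input neurons, and IllegalArgumentException stands in for VectorSizeMismatchException.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class AppendDemo {
    public static void main(String[] args) {
        AppendingNetwork net = new AppendingNetwork();
        List<String> inputLayer = Arrays.asList("in0", "in1"); // two input neurons
        System.out.println("before: " + net.inputCount());         // before: 0
        net.setInputNeurons(inputLayer);
        System.out.println("after 1st call: " + net.inputCount()); // after 1st call: 2
        net.setInputNeurons(inputLayer);
        System.out.println("after 2nd call: " + net.inputCount()); // after 2nd call: 4
        try {
            net.setInput(0, 1); // a two-value row, like the DataSetRows in the question
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}

// Hypothetical stand-in for a network whose input-neuron setter APPENDS,
// like the NeuralNetwork.setInputNeurons source quoted in the answer. Not Neuroph code.
class AppendingNetwork {
    private final List<String> inputNeurons = new ArrayList<>();

    // Same loop as the quoted setter: each call adds to the existing list.
    void setInputNeurons(List<String> neurons) {
        for (String neuron : neurons) {
            inputNeurons.add(neuron);
        }
    }

    int inputCount() {
        return inputNeurons.size();
    }

    // Assumed check: reject an input vector whose length differs from the number
    // of registered input neurons (IllegalArgumentException stands in for
    // Neuroph's VectorSizeMismatchException).
    void setInput(double... inputVector) {
        if (inputVector.length != inputNeurons.size()) {
            throw new IllegalArgumentException("input vector has " + inputVector.length
                + " values but the network has " + inputNeurons.size() + " input neurons");
        }
    }
}
```

After the second call the model holds four input neurons, so a two-value row is rejected.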

Here is the relevant source from the NeuralNetwork class:

/**
 * Sets input neurons
 *
 * @param inputNeurons array of input neurons
 */
public void setInputNeurons(List<Neuron> inputNeurons) {
    for (Neuron neuron : inputNeurons) {
        this.inputNeurons.add(neuron);
    }
}
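For contrast, a setter with replace semantics makes repeated calls harmless. The sketch below is hypothetical: ReplacingNetwork is not a Neuroph class, and the sketch only illustrates how clearing first differs from the append loop above. It makes no claim about any particular Neuroph release.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ReplaceDemo {
    public static void main(String[] args) {
        ReplacingNetwork net = new ReplacingNetwork();
        List<String> inputLayer = Arrays.asList("in0", "in1");
        net.setInputNeurons(inputLayer);
        net.setInputNeurons(inputLayer); // a second call no longer grows the list
        System.out.println(net.inputCount()); // 2
    }
}

// Hypothetical setter with REPLACE semantics; illustrative only.
class ReplacingNetwork {
    private final List<String> inputNeurons = new ArrayList<>();

    void setInputNeurons(List<String> neurons) {
        inputNeurons.clear();         // drop previously registered neurons
        inputNeurons.addAll(neurons); // then register the new set
    }

    int inputCount() {
        return inputNeurons.size();
    }
}
```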

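In theory, removing those four lines (the two setInputNeurons calls and the two setOutputNeurons calls) should get rid of the error.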
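This answer suggests printing the input-neuron count. Building on that, a small dimension check before training can turn the mismatch into an explicit message. DimensionCheck and requireMatchingDimensions are hypothetical names for illustration, and the sketch uses only plain ints so it runs without Neuroph.

```java
class DimensionCheck {
    // Hypothetical helper: fail fast before training when the network's
    // input/output neuron counts differ from the data set's row widths.
    static void requireMatchingDimensions(int networkInputs, int networkOutputs,
                                          int dataInputs, int dataOutputs) {
        if (networkInputs != dataInputs) {
            throw new IllegalStateException("network has " + networkInputs
                + " input neurons but data rows have " + dataInputs + " inputs");
        }
        if (networkOutputs != dataOutputs) {
            throw new IllegalStateException("network has " + networkOutputs
                + " output neurons but data rows have " + dataOutputs + " outputs");
        }
    }

    public static void main(String[] args) {
        requireMatchingDimensions(2, 1, 2, 1); // matching sizes: no exception
        try {
            requireMatchingDimensions(4, 1, 2, 1); // deliberately mismatched input count
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // network has 4 input neurons but data rows have 2 inputs
        }
    }
}
```

In the question's trainNeuralNetwork, you could call DimensionCheck.requireMatchingDimensions(ann.getInputNeurons().size(), ann.getOutputNeurons().size(), ds.getInputSize(), ds.getOutputSize()) just before ann.learn(ds). This assumes your Neuroph version exposes those getters.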