试图建立一个神经网络,结构有些问题

时间:2014-04-05 22:23:55

标签: java neural-network

我正在尝试建立一个非常简单的神经网络,它将采用多层感知器方法。

我有4个班级:

  

节点:只有一个没有任何内容的构造函数,以及一个输出方法(我仍然需要定义)

     

InputNode :将根据数据从.data文件读取数据并设置输入节点(双重类型)

     

HiddenLayer :此图层的大小将是每个Input ArrayList长度的2/3

     

NeuralNetwork :这将是主要的执行类,我将在其中设置具有随机权重的2-d双数组。这些权重将乘以输入数据以得到总和。这里会出现错误和列车方法。

我的问题与InputNode和HiddenLayer类有关。我希望InputNode类从.data文件中读取数据,并将.data中的每一行设置为13个输入节点。 .data文件中每行的最后一个数字代表0-3的预测数:对于某些东西为负或在3以上:对某些东西为正

示例:

如果我有13个输入,那么我应该有大约8个隐藏节点。每个隐藏节点将接收所有输入节点,并将每个输入节点的输入值乘以该特定隐藏节点的权重,并得出1或-1值。

我该如何设置?因为我最初设置我的NeuralNetwork类来设置输入ArrayList,所以有一些麻烦,但现在已经决定创建单独的类更好。

以下是示例数据集的样子:

Data

以下是我的Node Class的代码:

public class Node {

public Node(){

}

public void output(){
    //Don't know what to put here? Guess it can be modified for hiddenlayer and
        // Node classes 
}

}

以下是我的InputNode类的代码:

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class InputNode extends Node {

//declare arraylist for input nodes
private List<Node> inputs = new ArrayList<Node>();

//Constuctor for input nodes
public InputNode(File f){

    List<Node> tempInput = new ArrayList<Node>();
    this.inputs = tempInput;

    try {
        @SuppressWarnings("resource")
        Scanner inFile = new Scanner(f);

        //While there is another line in inFile.
        while (inFile.hasNextLine()){
            //Store that line into String line
            String line = inFile.nextLine();
            //Parition values separated by a comma
            String[] columns = line.split(",");
            /*System.out.println(columns.length); //Test code to see length of each column
            * code works and prints 14
            * */
            //create a double array of size columns
            double[] rows = new double[columns.length];
            for (int i = 0; i < columns.length; i++){
                //For each row...
                if (!columns[i].equals("?")){
                    //If the values in each row do not equal "?"
                    //Set rows[i] to the values in column[i]
                    rows[i] = Double.parseDouble(columns[i]);
                }
                else {
                    rows[i] = 0;
                }
            }
            inputs.add(rows);
            }

    } catch (FileNotFoundException e) {
        // TODO Auto-generated catch block
        e.printStackTrace();
    }

}




}

1 个答案:

答案 0 :(得分:2)

我认为这个问题的关键是如何识别'?'来自混合String数组的标记。我不确定这是你想要的,但我建议一种可能性。

跟随代码交换'?'标记为加倍'0.0'。

我已经更改了你的一些代码。 - 将ArrayList Generic从'Node'更改为'double []',因为它不能将双数组(代码中的'rows')分配给Object ArrayList。 - 添加try / catch短语,因为它无法将文字符号强制转换为Integer或Double。所以首先我们需要整理这个字符串是文字或数字。

public class InputNode {

    private final String REGEX = ",";
    private final String Q_MARKS = "?";
    private List<double[]> input;

    @SuppressWarnings("resource")
    public InputNode(File f){
        this.input = new ArrayList<double[]>();
        try{
            Scanner inFile = new Scanner(f);
            int cntLines = 0;

            while(inFile.hasNextLine()){
                String line = inFile.nextLine();
                String columns[] = line.split(REGEX);

                double rows[] = new double[columns.length];

                for(int i = 0; i < columns.length; i++){
                    //if String is '?' then 0.0, else parseDouble
                    if(!columns[i].equals(Q_MARKS)){
                        try{
                            rows[i] = Double.parseDouble(columns[i]);
                        }catch(Exception e){
                            System.out.println("data is not '?'");
                            continue;
                        }
                    }else{
                        System.out.println("data is '?' mark");
                        rows[i] = 0;
                    }
                }

                input.add(rows);
                cntLines++;
                System.out.println("num : " + cntLines + ", value : " + Arrays.toString(rows));
            }
            System.out.println("total Lines read : " + cntLines);
        }catch(Exception e){
            e.printStackTrace();
        }
    }
}; 

我的测试a.data文件是

1,2,3,4,5,6,7,8,9,10,11,12,3,0
1,2,3,4,5,6,7,8,9,10,?,12,6,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,1
1,2,3,4,5,6,7,18,19,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,3,0
1,2,3,4,5,6,7,8,9,10,?,?,?,1
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,2,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0

最后输出是:

Riddles +_+
num : 1, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 3.0, 0.0]
data is '?' mark
num : 2, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 12.0, 6.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 3, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 4, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 5, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 6, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 1.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 7, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 18.0, 19.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
num : 8, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 3.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 9, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 1.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 10, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
num : 11, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 2.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 12, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 13, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
total Lines read : 13

兑换所有'?'标记为'0.0',你可以用数学计算:D