我正在尝试建立一个非常简单的神经网络,它将采用多层感知器方法。
我有4个班级:
节点:只有一个没有任何内容的构造函数,以及一个输出方法(我仍然需要定义)
InputNode :将根据数据从.data文件读取数据并设置输入节点(双重类型)
HiddenLayer :此图层的大小将是每个Input ArrayList长度的2/3
NeuralNetwork :这将是主要的执行类,我将在其中设置具有随机权重的2-d双数组。这些权重将乘以输入数据以得到总和。这里会出现错误和列车方法。
我的问题与InputNode和HiddenLayer类有关。我希望InputNode类从.data文件中读取数据,并将.data中的每一行设置为13个输入节点。 .data文件中每行的最后一个数字代表0-3的预测数:对于某些东西为负或在3以上:对某些东西为正
示例:的
如果我有13个输入,那么我应该有大约8个隐藏节点。每个隐藏节点将接收所有输入节点,并将每个输入节点的输入值乘以该特定隐藏节点的权重,并得出1或-1值。
我该如何设置?因为我最初设置我的NeuralNetwork类来设置输入ArrayList,所以有一些麻烦,但现在已经决定创建单独的类更好。
以下是示例数据集的样子:
以下是我的Node Class的代码:
public class Node {
public Node(){
}
public void output(){
//Don't know what to put here? Guess it can be modified for hiddenlayer and
// Node classes
}
}
以下是我的InputNode类的代码:
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
public class InputNode extends Node {
//declare arraylist for input nodes
private List<Node> inputs = new ArrayList<Node>();
//Constuctor for input nodes
public InputNode(File f){
List<Node> tempInput = new ArrayList<Node>();
this.inputs = tempInput;
try {
@SuppressWarnings("resource")
Scanner inFile = new Scanner(f);
//While there is another line in inFile.
while (inFile.hasNextLine()){
//Store that line into String line
String line = inFile.nextLine();
//Parition values separated by a comma
String[] columns = line.split(",");
/*System.out.println(columns.length); //Test code to see length of each column
* code works and prints 14
* */
//create a double array of size columns
double[] rows = new double[columns.length];
for (int i = 0; i < columns.length; i++){
//For each row...
if (!columns[i].equals("?")){
//If the values in each row do not equal "?"
//Set rows[i] to the values in column[i]
rows[i] = Double.parseDouble(columns[i]);
}
else {
rows[i] = 0;
}
}
inputs.add(rows);
}
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
答案 0 :(得分:2)
我认为这个问题的关键是如何识别'?'来自混合String数组的标记。我不确定这是你想要的,但我建议一种可能性。
跟随代码交换'?'标记为加倍'0.0'。
我已经更改了你的一些代码。 - 将ArrayList Generic从'Node'更改为'double []',因为它不能将双数组(代码中的'rows')分配给Object ArrayList。 - 添加try / catch短语,因为它无法将文字符号强制转换为Integer或Double。所以首先我们需要整理这个字符串是文字或数字。
public class InputNode {
private final String REGEX = ",";
private final String Q_MARKS = "?";
private List<double[]> input;
@SuppressWarnings("resource")
public InputNode(File f){
this.input = new ArrayList<double[]>();
try{
Scanner inFile = new Scanner(f);
int cntLines = 0;
while(inFile.hasNextLine()){
String line = inFile.nextLine();
String columns[] = line.split(REGEX);
double rows[] = new double[columns.length];
for(int i = 0; i < columns.length; i++){
//if String is '?' then 0.0, else parseDouble
if(!columns[i].equals(Q_MARKS)){
try{
rows[i] = Double.parseDouble(columns[i]);
}catch(Exception e){
System.out.println("data is not '?'");
continue;
}
}else{
System.out.println("data is '?' mark");
rows[i] = 0;
}
}
input.add(rows);
cntLines++;
System.out.println("num : " + cntLines + ", value : " + Arrays.toString(rows));
}
System.out.println("total Lines read : " + cntLines);
}catch(Exception e){
e.printStackTrace();
}
}
};
我的测试a.data文件是
1,2,3,4,5,6,7,8,9,10,11,12,3,0
1,2,3,4,5,6,7,8,9,10,?,12,6,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,1
1,2,3,4,5,6,7,18,19,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,3,0
1,2,3,4,5,6,7,8,9,10,?,?,?,1
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,2,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
1,2,3,4,5,6,7,8,9,10,?,?,?,0
最后输出是:
Riddles +_+
num : 1, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 3.0, 0.0]
data is '?' mark
num : 2, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 12.0, 6.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 3, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 4, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 5, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 6, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 1.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 7, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 18.0, 19.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
num : 8, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 3.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 9, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 1.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 10, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
num : 11, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 2.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 12, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
data is '?' mark
data is '?' mark
data is '?' mark
num : 13, value : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 0.0]
total Lines read : 13
兑换所有'?'标记为'0.0',你可以用数学计算:D