我正在尝试编写我的第一个神经网络来玩连接四的游戏。 我使用 Java 和 deeplearning4j。 我试图实现一个遗传算法,但是当我训练网络一段时间时,网络的输出跳到 NaN 并且我无法分辨出我在哪里搞砸了如此严重的事情。 我将在下面发布所有 3 个类,其中 Game 是游戏逻辑和规则,VGFrame 是 UI,Main 是所有 nn 的东西。
我有一个包含 35 个神经网络的池,每次迭代我都会让最好的 5 个存活并繁殖,并稍微随机化新创建的神经网络。 为了评估网络,我让他们互相争斗,然后给胜者加分,之后输给分数。 由于我将一块石头放入已经满的列中,我希望神经网络至少能够在一段时间后按照规则玩游戏,但他们不能这样做。 我用谷歌搜索了 NaN 问题,它似乎是一个梯度梯度问题,但据我所知,这不应该发生在遗传算法中? 有什么想法可以查找错误或我的实施通常有什么问题吗?
主要
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
public class Main {
final int numRows = 7;
final int numColums = 6;
final int randSeed = 123;
MultiLayerNetwork[] models;
static Random random = new Random();
private static final Logger log = LoggerFactory.getLogger(Main.class);
final float learningRate = .8f;
int batchSize = 64; // Test batch size
int nEpochs = 1; // Number of training epochs
// --
public static Main current;
Game mainGame = new Game();
public static void main(String[] args) {
current = new Main();
current.frame = new VGFrame();
current.loadWeights();
}
private VGFrame frame;
private final double mutationChance = .05;
public Main() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).seed(randSeed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
.list()
.layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
.activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
.build();
models = new MultiLayerNetwork[35];
for (int i = 0; i < models.length; i++) {
models[i] = new MultiLayerNetwork(conf);
models[i].init();
}
}
public void addChip(int i, boolean b) {
if (mainGame.gameState == 0)
mainGame.addChip(i, b);
if (mainGame.gameState == 0) {
float[] f = Main.rowsToInput(mainGame.rows);
INDArray input = Nd4j.create(f);
INDArray output = models[0].output(input);
for (int i1 = 0; i1 < 7; i1++) {
System.out.println(i1 + ": " + output.getDouble(i1));
}
System.out.println("----------------");
mainGame.addChip(Main.getHighestOutput(output), false);
}
getFrame().paint(getFrame().getGraphics());
}
public void newGame() {
mainGame = new Game();
getFrame().paint(getFrame().getGraphics());
}
public void startTraining(int iterations) {
// --------------------------
for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
System.out.println("Iteration " + gameNumber + " of " + iterations);
float[] evaluation = new float[models.length];
for (int i = 0; i < models.length; i++) {
for (int j = 0; j < models.length; j++) {
if (i != j) {
Game g = new Game();
g.playFullGame(models[i], models[j]);
if (g.gameState == 1) {
evaluation[i] += 45;
evaluation[j] += g.turnNumber;
}
if (g.gameState == 2) {
evaluation[j] += 45;
evaluation[i] += g.turnNumber;
}
}
}
}
float[] evaluationSorted = evaluation.clone();
Arrays.sort(evaluationSorted);
// keep the best 4
int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
for (int i = 0; i < evaluation.length; i++) {
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
n1 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
n2 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
n3 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
n4 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
n5 = i;
}
models[0] = models[n1];
models[1] = models[n2];
models[2] = models[n3];
models[3] = models[n4];
models[4] = models[n5];
for (int i = 3; i < evaluationSorted.length; i++) {
// random parent/keep w8ts
double r = Math.random();
if (r > .3) {
models[i] = models[random.nextInt(3)].clone();
} else if (r > .1) {
models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
}
// Mutate
INDArray params = models[i].params();
models[i].setParams(mutate(params));
}
}
}
private INDArray mutate(INDArray params) {
double[] d = params.toDoubleVector();
for (int i = 0; i < d.length; i++) {
if (Math.random() < mutationChance)
d[i] += (Math.random() - .5) * learningRate;
}
return Nd4j.create(d);
}
private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
double[] d = m1.params().toDoubleVector();
double[] d2 = m2.params().toDoubleVector();
for (int i = 0; i < d.length; i++) {
if (Math.random() < .5)
d[i] += d2[i];
}
return Nd4j.create(d);
}
static int getHighestOutput(INDArray output) {
int x = 0;
for (int i = 0; i < 7; i++) {
if (output.getDouble(i) > output.getDouble(x))
x = i;
}
return x;
}
static float[] rowsToInput(byte[][] rows) {
float[] f = new float[7 * 6];
for (int i = 0; i < 6; i++) {
for (int j = 0; j < 7; j++) {
// f[j + i * 7] = rows[j][i] / 2f;
f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
}
}
return f;
}
public void saveWeights() {
log.info("Saving model");
for (int i = 0; i < models.length; i++) {
File resourcesDirectory = new File("src/resources/model" + i);
try {
models[i].save(resourcesDirectory, true);
} catch (IOException e) {
e.printStackTrace();
}
}
}
public void loadWeights() {
if (new File("src/resources/model0").exists()) {
for (int i = 0; i < models.length; i++) {
File resourcesDirectory = new File("src/resources/model" + i);
try {
models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
System.out.println("col: " + models[0].params().shapeInfoToString());
}
public VGFrame getFrame() {
return frame;
}
}
VGFrame
import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;
public class VGFrame extends JFrame {
JTextField iterations;
/**
*
*/
private static final long serialVersionUID = 1L;
public VGFrame() {
super("Vier Gewinnt");
this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
this.setSize(1300, 800);
this.setVisible(true);
JPanel panelGame = new JPanel();
panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
this.add(panelGame);
var handler = new Handler();
var menuHandler = new MenuHandler();
JButton b1 = new JButton("1");
JButton b2 = new JButton("2");
JButton b3 = new JButton("3");
JButton b4 = new JButton("4");
JButton b5 = new JButton("5");
JButton b6 = new JButton("6");
JButton b7 = new JButton("7");
b1.addActionListener(handler);
b2.addActionListener(handler);
b3.addActionListener(handler);
b4.addActionListener(handler);
b5.addActionListener(handler);
b6.addActionListener(handler);
b7.addActionListener(handler);
panelGame.add(b1);
panelGame.add(b2);
panelGame.add(b3);
panelGame.add(b4);
panelGame.add(b5);
panelGame.add(b6);
panelGame.add(b7);
JButton buttonTrain = new JButton("Train");
JButton buttonNewGame = new JButton("New Game");
JButton buttonSave = new JButton("Save Weights");
JButton buttonLoad = new JButton("Load Weights");
iterations = new JTextField("1000");
buttonTrain.addActionListener(menuHandler);
buttonNewGame.addActionListener(menuHandler);
buttonSave.addActionListener(menuHandler);
buttonLoad.addActionListener(menuHandler);
iterations.addActionListener(menuHandler);
panelGame.add(iterations);
panelGame.add(buttonTrain);
panelGame.add(buttonNewGame);
panelGame.add(buttonSave);
panelGame.add(buttonLoad);
this.validate();
}
@Override
public void paint(Graphics g) {
super.paint(g);
if (Main.current.mainGame.rows == null)
return;
var rows = Main.current.mainGame.rows;
for (int i = 0; i < rows.length; i++) {
for (int j = 0; j < rows[0].length; j++) {
if (rows[i][j] == 0)
break;
g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
}
}
}
public void update() {
}
}
class Handler implements ActionListener {
@Override
public void actionPerformed(ActionEvent event) {
if (Main.current.mainGame.playersTurn)
Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
}
}
class MenuHandler implements ActionListener {
@Override
public void actionPerformed(ActionEvent event) {
switch (event.getActionCommand()) {
case "New Game":
Main.current.newGame();
break;
case "Train":
Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
break;
case "Save Weights":
Main.current.saveWeights();
break;
case "Load Weights":
Main.current.loadWeights();
break;
}
}
}
游戏
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class Game {
int turnNumber = 0;
byte[][] rows = new byte[7][6];
boolean playersTurn = true;
int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw
public boolean isRunning() {
return this.gameState == 0;
}
public void addChip(int x, boolean player1) {
turnNumber++;
byte b = nextRow(x);
if (b == 6) {
gameState = player1 ? 2 : 1;
return;
}
rows[x][b] = (byte) (player1 ? 1 : 2);
gameState = checkWinner(x, b);
}
private byte nextRow(int x) {
for (byte i = 0; i < rows[x].length; i++) {
if (rows[x][i] == 0)
return i;
}
return 6;
}
// 0 continue, 1 Player won, 2 ai won, 3 Draw
private int checkWinner(int x, int y) {
int color = rows[x][y];
// Vertikal
if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
return rows[x][y];
// Horizontal
if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
return rows[x][y];
// Diagonal1
if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
return rows[x][y];
// Diagonal2
if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
return rows[x][y];
for (byte[] bs : rows) {
for (byte s : bs) {
if (s == 0)
return 0;
}
}
return 3; // Draw
}
private int getCount(int x, int y, int dirX, int dirY) {
int color = rows[x][y];
int count = 0;
while (true) {
x += dirX;
y += dirY;
if (x < 0 | x > 6 | y < 0 | y > 5)
break;
if (color != rows[x][y])
break;
count++;
}
return count;
}
public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
boolean player1 = true;
while (this.gameState == 0) {
float[] f = Main.rowsToInput(this.rows);
INDArray input = Nd4j.create(f);
this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
player1 = !player1;
}
}
}
答案 0 :(得分:7)
快速浏览一下,并根据对乘数变体的分析,似乎 NaN
是由 算术下溢产生的,由梯度太小导致(< em>太接近绝对 0).
这是代码中最可疑的部分:
f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
如果 rows[j][i] == 1
则存储 0f
。我不知道神经网络(甚至是 java)是如何管理的,但从数学上讲,有限大小的浮点数不能包含零。
即使您的代码会用一些额外的盐来改变 0f
,这些数组值的结果也会有变得太接近于零的风险。由于表示实数时精度有限,无法表示非常接近于零的值,因此 NaN
。
这些值有一个非常友好的名称:subnormal numbers。
<块引用>幅度小于最小法线的任何非零数 数字低于正常。
<块引用>与 IEEE 754-1985 一样,标准推荐 0 表示 NaN 信号,1 表示安静的 NaN,这样可以通过仅将此位更改为 1 来使信令 NaN 安静下来,而反过来可以产生无穷大的编码。
上面的文字在这里很重要:根据标准,您实际上是在指定一个 NaN
,其中存储了任何 0f
值。
即使名称具有误导性,Float.MIN_VALUE
也是一个正值,高于 0:
真实最小float
值实际上是:-Float.MAX_VALUE
。
Is floating point math subnormal?
如果您检查问题仅仅是因为 0f
值,您可以将它们更改为代表类似内容的其他值; Float.MIN_VALUE
、Float.MIN_NORMAL
等。像这样的事情,也在可能发生这种情况的代码的其他可能部分。仅以这些为例,并使用这些范围:
rows[j][i] == 1 ? Float.MIN_VALUE : 1f;
rows[j][i] == 1 ? Float.MIN_NORMAL : Float.MAX_VALUE/2;
rows[j][i] == 1 ? -Float.MAX_VALUE/2 : Float.MAX_VALUE/2;
即便如此,根据这些值的更改方式,这也可能导致 NaN
。
如果是这样,您应该标准化这些值。您可以尝试为此应用 GradientNormalizer。在你的网络初始化中,应该为每一层(或那些有问题的)定义这样的东西:
new NeuralNetConfiguration
.Builder()
.weightInit(WeightInit.XAVIER)
(...)
.layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) //this
.build())
(...)
有不同的规范化器,因此请选择最适合您的架构的规范化器,以及哪些层应包含一个。选项是:
<块引用>RenormalizeL2PerLayer
通过除以 L2 范数来重新缩放梯度 图层的所有渐变。
RenormalizeL2PerParamType
通过除以 L2 重新缩放梯度 梯度的范数,分别针对其中的每种类型的参数 层。这与 RenormalizeL2PerLayer 的不同之处在于,这里的每个 参数类型(权重、偏差等)单独标准化。为了 例如,在 MLP/FeedForward 网络中(其中 G 是梯度 向量),输出如下:
GOut_weight = G_weight / l2(G_weight) GOut_bias = G_bias / l2(G_bias)
ClipElementWiseAbsoluteValue
在每个元素上剪辑渐变
基础。对于每个梯度 g,设置 g <- sign(g) max(maxAllowedValue,|g|)。
即,如果参数梯度的绝对值大于
阈值,截断它。例如,如果阈值 = 5,则值
范围 -5
ClipL2PerLayer
条件重整化。有点类似 RenormalizeL2PerLayer,当且仅当这个策略缩放梯度 如果梯度的 L2 范数(对于整个层)超过指定的 临界点。具体来说,如果 G 是层的梯度向量,则:
GOut = G if l2Norm(G) < threshold(即没有变化)GOut = 阈值 * G / l2Norm(G)
ClipL2PerParamType
条件重整化。非常 类似于 ClipL2PerLayer,但不是按层进行裁剪,而是执行 分别裁剪每个参数类型。例如在一个经常性的 神经网络、输入权重梯度、循环权重梯度和 偏置梯度都是单独裁剪的。
Here 您可以找到这些 GradientNormalizers
应用的完整示例。
答案 1 :(得分:1)
我想我终于明白了。我试图使用 deeplearning4j-ui 可视化网络,但遇到了一些不兼容的版本错误。更改版本后,我收到一个新错误,指出网络输入需要一个 2d 数组,我在互联网上发现所有版本都是如此。
所以我改变了
float[] f = new float[7 * 6];
Nd4j.create(f);
到
float[][] f = new float[1][7 * 6];
Nd4j.createFromArray(f);
NaN 值终于消失了。 @aran所以我想假设不正确的输入绝对是正确的方向。非常感谢您的帮助:)