神经网络返回 NaN 作为输出

时间:2021-02-05 01:53:36

标签: java neural-network nan deeplearning4j

我正在尝试编写我的第一个神经网络来玩连接四的游戏。 我使用 Java 和 deeplearning4j。 我试图实现一个遗传算法,但是当我训练网络一段时间时,网络的输出跳到 NaN 并且我无法分辨出我在哪里搞砸了如此严重的事情。 我将在下面发布所有 3 个类,其中 Game 是游戏逻辑和规则,VGFrame 是 UI,Main 是所有 nn 的东西。

我有一个包含 35 个神经网络的池,每次迭代我都会让最好的 5 个存活并繁殖,并稍微随机化新创建的神经网络。 为了评估网络,我让他们互相争斗,然后给胜者加分,之后输给分数。 由于我将一块石头放入已经满的列中,我希望神经网络至少能够在一段时间后按照规则玩游戏,但他们不能这样做。 我用谷歌搜索了 NaN 问题,它似乎是一个梯度梯度问题,但据我所知,这不应该发生在遗传算法中? 有什么想法可以查找错误或我的实施通常有什么问题吗?

主要

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;

public class Main {
    final int numRows = 7;
    final int numColums = 6;
    final int randSeed = 123;
    MultiLayerNetwork[] models;

    static Random random = new Random();
    private static final Logger log = LoggerFactory.getLogger(Main.class);
    final float learningRate = .8f;
    int batchSize = 64; // Test batch size
    int nEpochs = 1; // Number of training epochs
    // --
    public static Main current;
    Game mainGame = new Game();

    public static void main(String[] args) {
        current = new Main();
        current.frame = new VGFrame();
        current.loadWeights();
    }

    private VGFrame frame;
    private final double mutationChance = .05;

    public Main() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU).seed(randSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
                        .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
                .build();
        models = new MultiLayerNetwork[35];
        for (int i = 0; i < models.length; i++) {
            models[i] = new MultiLayerNetwork(conf);
            models[i].init();
        }

    }

    public void addChip(int i, boolean b) {
        if (mainGame.gameState == 0)
            mainGame.addChip(i, b);
        if (mainGame.gameState == 0) {
            float[] f = Main.rowsToInput(mainGame.rows);
            INDArray input = Nd4j.create(f);
            INDArray output = models[0].output(input);
            for (int i1 = 0; i1 < 7; i1++) {
                System.out.println(i1 + ": " + output.getDouble(i1));
            }
            System.out.println("----------------");
            mainGame.addChip(Main.getHighestOutput(output), false);
        }
        getFrame().paint(getFrame().getGraphics());
    }

    public void newGame() {
        mainGame = new Game();
        getFrame().paint(getFrame().getGraphics());
    }

    public void startTraining(int iterations) {

        // --------------------------
        for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
            System.out.println("Iteration " + gameNumber + " of " + iterations);
            float[] evaluation = new float[models.length];
            for (int i = 0; i < models.length; i++) {
                for (int j = 0; j < models.length; j++) {
                    if (i != j) {
                        Game g = new Game();
                        g.playFullGame(models[i], models[j]);
                        if (g.gameState == 1) {
                            evaluation[i] += 45;
                            evaluation[j] += g.turnNumber;
                        }
                        if (g.gameState == 2) {
                            evaluation[j] += 45;
                            evaluation[i] += g.turnNumber;
                        }
                    }
                }
            }

            float[] evaluationSorted = evaluation.clone();
            Arrays.sort(evaluationSorted);
            // keep the best 4
            int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
            for (int i = 0; i < evaluation.length; i++) {
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
                    n1 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
                    n2 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
                    n3 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
                    n4 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
                    n5 = i;
            }
            models[0] = models[n1];
            models[1] = models[n2];
            models[2] = models[n3];
            models[3] = models[n4];
            models[4] = models[n5];

            for (int i = 3; i < evaluationSorted.length; i++) {
                // random parent/keep w8ts
                double r = Math.random();
                if (r > .3) {
                    models[i] = models[random.nextInt(3)].clone();

                } else if (r > .1) {
                    models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
                }
                // Mutate
                INDArray params = models[i].params();
                models[i].setParams(mutate(params));
            }
        }
    }

    private INDArray mutate(INDArray params) {
        double[] d = params.toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < mutationChance)
                d[i] += (Math.random() - .5) * learningRate;

        }
        return Nd4j.create(d);
    }

    private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        double[] d = m1.params().toDoubleVector();
        double[] d2 = m2.params().toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < .5)
                d[i] += d2[i];
        }
        return Nd4j.create(d);
    }

    static int getHighestOutput(INDArray output) {
        int x = 0;
        for (int i = 0; i < 7; i++) {
            if (output.getDouble(i) > output.getDouble(x))
                x = i;
        }
        return x;
    }

    static float[] rowsToInput(byte[][] rows) {
        float[] f = new float[7 * 6];
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 7; j++) {
                // f[j + i * 7] = rows[j][i] / 2f;
                f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
            }
        }
        return f;
    }

    public void saveWeights() {
        log.info("Saving model");
        for (int i = 0; i < models.length; i++) {
            File resourcesDirectory = new File("src/resources/model" + i);
            try {
                models[i].save(resourcesDirectory, true);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public void loadWeights() {
        if (new File("src/resources/model0").exists()) {
            for (int i = 0; i < models.length; i++) {
                File resourcesDirectory = new File("src/resources/model" + i);
                try {

                    models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
        System.out.println("col: " + models[0].params().shapeInfoToString());
    }

    public VGFrame getFrame() {
        return frame;
    }

}

VGFrame

import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;

public class VGFrame extends JFrame {
    JTextField iterations;
    /**
     * 
     */
    private static final long serialVersionUID = 1L;

    public VGFrame() {
        super("Vier Gewinnt");
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setSize(1300, 800);
        this.setVisible(true);
        JPanel panelGame = new JPanel();
        panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
        this.add(panelGame);

        var handler = new Handler();
        var menuHandler = new MenuHandler();

        JButton b1 = new JButton("1");
        JButton b2 = new JButton("2");
        JButton b3 = new JButton("3");
        JButton b4 = new JButton("4");
        JButton b5 = new JButton("5");
        JButton b6 = new JButton("6");
        JButton b7 = new JButton("7");
        b1.addActionListener(handler);
        b2.addActionListener(handler);
        b3.addActionListener(handler);
        b4.addActionListener(handler);
        b5.addActionListener(handler);
        b6.addActionListener(handler);
        b7.addActionListener(handler);
        panelGame.add(b1);
        panelGame.add(b2);
        panelGame.add(b3);
        panelGame.add(b4);
        panelGame.add(b5);
        panelGame.add(b6);
        panelGame.add(b7);

        JButton buttonTrain = new JButton("Train");
        JButton buttonNewGame = new JButton("New Game");
        JButton buttonSave = new JButton("Save Weights");
        JButton buttonLoad = new JButton("Load Weights");

        iterations = new JTextField("1000");

        buttonTrain.addActionListener(menuHandler);
        buttonNewGame.addActionListener(menuHandler);
        buttonSave.addActionListener(menuHandler);
        buttonLoad.addActionListener(menuHandler);
        iterations.addActionListener(menuHandler);

        panelGame.add(iterations);
        panelGame.add(buttonTrain);
        panelGame.add(buttonNewGame);
        panelGame.add(buttonSave);
        panelGame.add(buttonLoad);

        this.validate();
    }

    @Override
    public void paint(Graphics g) {
        super.paint(g);
        if (Main.current.mainGame.rows == null)
            return;
        var rows = Main.current.mainGame.rows;
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < rows[0].length; j++) {
                if (rows[i][j] == 0)
                    break;

                g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
                g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
            }
        }
    }

    public void update() {
    }
}

class Handler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        if (Main.current.mainGame.playersTurn)
            Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
    }
}

class MenuHandler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        switch (event.getActionCommand()) {
        case "New Game":
            Main.current.newGame();
            break;
        case "Train":
            Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
            break;
        case "Save Weights":
            Main.current.saveWeights();
            break;
        case "Load Weights":
            Main.current.loadWeights();
            break;
        }

    }
}

游戏

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Game {

    int turnNumber = 0;
    byte[][] rows = new byte[7][6];
    boolean playersTurn = true;

    int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw

    public boolean isRunning() {
        return this.gameState == 0;
    }

    public void addChip(int x, boolean player1) {
        turnNumber++;
        byte b = nextRow(x);
        if (b == 6) {
            gameState = player1 ? 2 : 1;
            return;
        }
        rows[x][b] = (byte) (player1 ? 1 : 2);
        gameState = checkWinner(x, b);
    }

    private byte nextRow(int x) {
        for (byte i = 0; i < rows[x].length; i++) {
            if (rows[x][i] == 0)
                return i;
        }
        return 6;
    }

    // 0 continue, 1 Player won, 2 ai won, 3 Draw
    private int checkWinner(int x, int y) {
        int color = rows[x][y];
        // Vertikal
        if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
            return rows[x][y];

        // Horizontal
        if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
            return rows[x][y];

        // Diagonal1
        if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
            return rows[x][y];
        // Diagonal2
        if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
            return rows[x][y];
        
        for (byte[] bs : rows) {
            for (byte s : bs) {
                if (s == 0)
                    return 0;
            }
        }
        return 3; // Draw
    }

    private int getCount(int x, int y, int dirX, int dirY) {
        int color = rows[x][y];
        int count = 0;
        while (true) {
            x += dirX;
            y += dirY;
            if (x < 0 | x > 6 | y < 0 | y > 5)
                break;
            if (color != rows[x][y])
                break;
            count++;
        }
        return count;
    }

    public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        boolean player1 = true;
        while (this.gameState == 0) {
            float[] f = Main.rowsToInput(this.rows);
            INDArray input = Nd4j.create(f);
            this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
            player1 = !player1;
        }
    }
}

2 个答案:

答案 0 :(得分:7)

快速浏览一下,并根据对乘数变体的分析,似乎 NaN 是由 算术下溢产生的,由梯度太小导致(< em>太接近绝对 0).

这是代码中最可疑的部分:

 f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);

如果 rows[j][i] == 1 则存储 0f。我不知道神经网络(甚至是 java)是如何管理的,但从数学上讲,有限大小的浮点数不能包含零

即使您的代码会用一些额外的盐来改变 0f,这些数组值的结果也会有变得太接近于零的风险。由于表示实数时精度有限,无法表示非常接近于零的值,因此 NaN

这些值有一个非常友好的名称:subnormal numbers

<块引用>

幅度小于最小法线的任何非零数 数字低于正常

enter image description here

IEEE_754

<块引用>

与 IEEE 754-1985 一样,标准推荐 0 表示 NaN 信号,1 表示安静的 NaN,这样可以通过仅将此位更改为 1 来使信令 NaN 安静下来,而反过来可以产生无穷大的编码。

上面的文字在这里很重要:根据标准,您实际上是在指定一个 NaN,其中存储了任何 0f 值。


即使名称具有误导性,Float.MIN_VALUE 也是一个值,高于 0

enter image description here

真实最小float值实际上是:-Float.MAX_VALUE

Is floating point math subnormal?


标准化梯度

如果您检查问题仅仅是因为 0f 值,您可以将它们更改为代表类似内容的其他值; Float.MIN_VALUEFloat.MIN_NORMAL 等。像这样的事情,也在可能发生这种情况的代码的其他可能部分。仅以这些为例,并使用这些范围:

rows[j][i] == 1 ? Float.MIN_VALUE : 1f;

rows[j][i] == 1 ?  Float.MIN_NORMAL : Float.MAX_VALUE/2;

rows[j][i] == 1 ? -Float.MAX_VALUE/2 : Float.MAX_VALUE/2;

即便如此,根据这些值的更改方式,这也可能导致 NaN。 如果是这样,您应该标准化这些值。您可以尝试为此应用 GradientNormalizer。在你的网络初始化中,应该为每一层(或那些有问题的)定义这样的东西:

new NeuralNetConfiguration
  .Builder()
  .weightInit(WeightInit.XAVIER)
  (...)
  .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
        .weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) //this   
        .build())
  
  (...)

有不同的规范化器,因此请选择最适合您的架构的规范化器,以及哪些层应包含一个。选项是:

<块引用>

GradientNormalization

  • RenormalizeL2PerLayer

    通过除以 L2 范数来重新缩放梯度 图层的所有渐变。

  • RenormalizeL2PerParamType

    通过除以 L2 重新缩放梯度 梯度的范数,分别针对其中的每种类型的参数 层。这与 RenormalizeL2PerLayer 的不同之处在于,这里的每个 参数类型(权重、偏差等)单独标准化。为了 例如,在 MLP/FeedForward 网络中(其中 G 是梯度 向量),输出如下:

    GOut_weight = G_weight / l2(G_weight) GOut_bias = G_bias / l2(G_bias)

  • ClipElementWiseAbsoluteValue

    在每个元素上剪辑渐变 基础。对于每个梯度 g,设置 g <- sign(g) max(maxAllowedValue,|g|)。 即,如果参数梯度的绝对值大于 阈值,截断它。例如,如果阈值 = 5,则值 范围 -55 是 设置为 5。

  • ClipL2PerLayer

    条件重整化。有点类似 RenormalizeL2PerLayer,当且仅当这个策略缩放梯度 如果梯度的 L2 范数(对于整个层)超过指定的 临界点。具体来说,如果 G 是层的梯度向量,则:

    GOut = G if l2Norm(G) < threshold(即没有变化)GOut = 阈值 * G / l2Norm(G)

  • ClipL2PerParamType

    条件重整化。非常 类似于 ClipL2PerLayer,但不是按层进行裁剪,而是执行 分别裁剪每个参数类型。例如在一个经常性的 神经网络、输入权重梯度、循环权重梯度和 偏置梯度都是单独裁剪的。


Here 您可以找到这些 GradientNormalizers 应用的完整示例。

答案 1 :(得分:1)

我想我终于明白了。我试图使用 deeplearning4j-ui 可视化网络,但遇到了一些不兼容的版本错误。更改版本后,我收到一个新错误,指出网络输入需要一个 2d 数组,我在互联网上发现所有版本都是如此。

所以我改变了

float[] f = new float[7 * 6];
Nd4j.create(f);

float[][] f = new float[1][7 * 6];
Nd4j.createFromArray(f);

NaN 值终于消失了。 @aran所以我想假设不正确的输入绝对是正确的方向。非常感谢您的帮助:)