神经网络ReLU

时间:2019-12-08 03:42:30

标签: java tensorflow neural-network artificial-intelligence gradient-descent

我试图通过对自定义神经网络进行编码来了解神经网络的工作原理。

它在使用S形函数时确实工作得很好,但是当我改用ReLU时,我的神经网络似乎无法正常工作。学习速度很慢,通常只输出0。

这是我的激活功能代码:

private Double reLu(Double a) {
    if (a >= 0.0)
        return a;
    else
        return 0.01 * a;
}

private Double dReLu(Double a) {
    if (a >= 0.0)
        return 1.0;
    else
        return 0.01;
}

private Double sigmoid(Double a) {
    return 1 / (1 + Math.pow(2.71828, -1.0 * a));
}

private Double dsigmoid(Double a) {
    return Math.pow(2.71828, -1.0 * a) / Math.pow(Math.pow(2.71828, -1.0 * a) + 1, 2);
}

以及用于预测,学习功能的代码

(neurons [] []仅用于隐藏层/ weights [i] [j] [l]表示(neurons [i-1] [j]到神经元[i] [l]的权重)/和bias [ ] []用于隐藏层,输出/和input [],output []很明显)

    public Double[] predict(Double[] input) {
        if (inputs.length != input.length)
            return new Double[]{-1.0};

        else {
            for (int i = 0; i < sums.length; i++) {
                for (int j = 0; j < sums[i].length; j++) {
                    sums[i][j] = 0.0;
                }
            }
            System.arraycopy(input, 0, inputs, 0, input.length);
            for (int i = 0; i < neurons.length + 1; i++) {
                if (i == 0) {
                    for (int j = 0; j < neurons[0].length; j++) {
                        for (int l = 0; l < inputs.length; l++) {
                            sums[0][j] += weights[0][l][j] * inputs[l];
                        }

                        sums[0][j] += bias[0][j];

                        neurons[0][j] = reLu(sums[0][j]);
                    }
                }

                else if (i == neurons.length) {
                    for (int j = 0; j < outputs.length; j++) {
                        for (int l = 0; l < neurons[i - 1].length; l++) {
                            sums[i][j] += weights[i][l][j] * neurons[i - 1][l];
                        }
                        sums[i][j] += bias[i][j];
                        outputs[j] = sigmoid(sums[i][j]);
                    }
                }

                else {
                    for (int j = 0; j < neurons[i].length; j++) {
                        for (int l = 0; l < neurons[i - 1].length; l++) {
                            sums[i][j] += weights[i][l][j] * neurons[i - 1][l];
                        }
                        sums[i][j] += bias[i][j];
                        neurons[i][j] = reLu(sums[i][j]);
                    }
                }
            }
        }

        return this.outputs;        
    }

(dCost [i] [j]是Neurons [i] [j]相对于Cost /的布尔值,布尔dosSum度量如果将平均权重与实际重量相加,则与biasNudge相同/ howManyTimePassed是变量,用于计算批量大小/ multyPlyer只是LearningRate)

public void learn(Double[] Desired, boolean doesSum, Double multyPlyer) {
        Double[][] dCost = new Double[neurons.length + 1][];

        for (int i = neurons.length; i >= 0; i--) {
            if (i == neurons.length) {
                dCost[i] = new Double[outputs.length];
                for (int j = 0; j < outputs.length; j++) {
                    dCost[i][j] = 2 * (outputs[j] - Desired[j]);
                }
            }

            else if (i == neurons.length - 1) {
                Double temp = 0.0;
                dCost[i] = new Double[neurons[i].length];
                for (int j = 0; j < neurons[i].length; j++) {
                    for (int l = 0; l < outputs.length; l++) {
                        temp += weights[i + 1][j][l] * dsigmoid(sums[i + 1][l]) * dCost[i + 1][l];
                    }

                    dCost[i][j] = temp;
                }
            }

            else {
                Double temp = 0.0;
                dCost[i] = new Double[neurons[i].length];
                for (int j = 0; j < neurons[i].length; j++) {
                    for (int l = 0; l < neurons[i + 1].length; l++) {
                        temp += weights[i + 1][j][l] * dReLu(sums[i + 1][l]) * dCost[i + 1][l];
                    }

                    dCost[i][j] = temp;
                }
            }
        }

        for (int i = 0; i < neurons.length + 1; i++) {
            if (i == 0) {
                for (int j = 0; j < inputs.length; j++) {
                    for (int l = 0; l < neurons[i].length; l++) {
                        weightsNudge[i][j][l] -= inputs[j] * dReLu(sums[i][l]) * dCost[i][l] * multyPlyer;
                    }
                }
            }


            else if (i == neurons.length) {
                for (int j = 0; j < neurons[i - 1].length; j++) {
                    for (int l = 0; l < outputs.length; l++) {
                        weightsNudge[i][j][l] -= neurons[i - 1][j] * dsigmoid(sums[i][l]) * dCost[i][l] * multyPlyer;
                    }
                }
            }

            else {
                for (int j = 0; j < neurons[i - 1].length; j++)  {
                    for (int l = 0; l < neurons[i].length; l++) {
                        weightsNudge[i][j][l] -= neurons[i - 1][j] * dReLu(sums[i][l]) * dCost[i][l] * multyPlyer;
                    }
                }
            }
        }

        for (int i = 0; i < neurons.length + 1; i++) {
            for (int j = 0; j < sums[i].length; j++) {
                if (i == neurons.length)
                    biasNudge[i][j] -= dsigmoid(sums[i][j]) * dCost[i][j] * multyPlyer;
                else
                    biasNudge[i][j] -= dReLu(sums[i][j]) * dCost[i][j] * multyPlyer;
            }
        }

        howManyTimePassed++;

        if (doesSum) {
            for (int i = 0; i < weights.length; i++) {
                for (int j = 0; j < weights[i].length; j++) {
                    for (int l = 0; l < weights[i][j].length; l++) {
                        weights[i][j][l] += weightsNudge[i][j][l] / howManyTimePassed;
                        weightsNudge[i][j][l] = 0.0;
                    }
                }
            }

            for (int i = 0; i < bias.length; i++) {
                for (int j = 0; j < bias[i].length; j++) {
                    bias[i][j] += biasNudge[i][j] / howManyTimePassed * multyPlyer;
                    biasNudge[i][j] = 0.0;
                }
            }

            howManyTimePassed = 0;
        }
    }

有人可以阅读下面的代码并发现任何愚蠢的东西吗?拜托。

0 个答案:

没有答案