反向传播神经网络无效

时间:2016-09-20 15:10:09

标签: java matrix neural-network backpropagation gradient-descent

我刚开始学习机器学习和神经网络,所以仍在努力理解反向传播是如何工作的。我尝试用简单的基于矩阵的方法在 Java 中实现一个简单的神经网络。如果只提供一个训练样本,网络能完美运行;但如果使用多个训练样本,输出总是各个期望输出的平均值。 http://neuralnetworksanddeeplearning.com/images/tikz21.png

package neuralnetwork;
/**
 * @author Paolo Pellizzoni
 */

/**
 * Minimal fully-connected network (input layer, one hidden layer, output layer)
 * with sigmoid activations, trained by full-batch gradient descent.
 *
 * <p>Data matrices hold one example per <em>column</em>: {@link #x} is
 * {@code in_l x n} and {@link #y} is {@code out_l x n}. All state is static, so
 * the class is not thread-safe.
 */
public class NeuralNetwork {

    /** Number of input neurons. */
    static final int in_l = 2;
    /** Number of hidden neurons. */
    static final int h_l = 5;
    /** Number of output neurons. */
    static final int out_l = 1;

    /** Iteration count used by {@link #trainNN(double)}. */
    private static final int DEFAULT_ITERATIONS = 500;

    /** Hidden-layer weights, {@code h_l x in_l}. */
    public static double[][] w2 = new double[h_l][in_l];
    /** Output-layer weights, {@code out_l x h_l}. */
    public static double[][] w3 = new double[out_l][h_l];
    /** Hidden-layer biases. */
    public static double[] b2 = new double[h_l];
    /** Output-layer biases. */
    public static double[] b3 = new double[out_l];

    /** Training inputs, one example per column. */
    public static double[][] x = {{3, 4}, {2, 3}};
    /** Training targets, one example per column. */
    public static double[][] y = {{0.3, 0.7}};
    /** Input evaluated by {@link #main} after training. */
    public static double[][] test = {{3}, {2}};
    // Author's note: using x = {{3},{2}} and y = {{0.3}} it works

    /**
     * Trains on {@link #x}/{@link #y} and prints the network output for {@link #test}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        trainNN(0.2);
        double[][] m = a_3(test);

        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                System.out.print(m[i][j] + " ");
            }
            System.out.println();
        }
    }

    // ---------- FUNCTIONS ----------

    /**
     * Fills {@code m} in place with values from {@link Math#random()}.
     *
     * @param m matrix to overwrite
     */
    static void inizialize_weights(double[][] m) {
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                m[i][j] = Math.random();
            }
        }
    }

    /**
     * Trains with the default number of iterations ({@value #DEFAULT_ITERATIONS}).
     * That may be too few to converge on more than one example; use
     * {@link #trainNN(double, int)} to run longer.
     *
     * @param rate learning rate
     */
    static void trainNN(double rate) {
        trainNN(rate, DEFAULT_ITERATIONS);
    }

    /**
     * Re-initializes the weights randomly and runs full-batch gradient descent on
     * {@link #x}/{@link #y}. Biases are not reset; they continue from their
     * current values.
     *
     * @param rate       learning rate
     * @param iterations number of gradient steps; must not be negative
     * @throws IllegalArgumentException if {@code iterations} is negative
     */
    static void trainNN(double rate, int iterations) {
        if (iterations < 0) {
            throw new IllegalArgumentException("iterations must be >= 0, got " + iterations);
        }
        inizialize_weights(w2);
        inizialize_weights(w3);

        for (int c = 0; c < iterations; c++) {
            // Compute every gradient from the current parameters before updating any of them.
            double[][] gradW3 = dJ_w3(x, y);
            double[][] gradW2 = dJ_w2(x, y);
            double[] gradB3 = dJ_b3(x, y);
            double[] gradB2 = dJ_b2(x, y);
            w3 = matrix_sum(w3, gradW3, -rate);
            w2 = matrix_sum(w2, gradW2, -rate);
            b3 = vect_sum(b3, gradB3, -rate);
            b2 = vect_sum(b2, gradB2, -rate);
        }
    }

    /** Output-layer activations, {@code sigmoid(z_3)}. */
    static double[][] a_3(double[][] inputs) {
        return sigmoid(z_3(inputs));
    }

    /** Output-layer weighted inputs, {@code w3 * a_2 + b3}. */
    static double[][] z_3(double[][] inputs) {
        return matrix_sum_vect(matrix_product(w3, a_2(inputs)), b3, 1);
    }

    /** Hidden-layer activations, {@code sigmoid(z_2)}. */
    static double[][] a_2(double[][] inputs) {
        return sigmoid(z_2(inputs));
    }

    /** Hidden-layer weighted inputs, {@code w2 * inputs + b2}. */
    static double[][] z_2(double[][] inputs) {
        return matrix_sum_vect(matrix_product(w2, inputs), b2, 1);
    }

    /** Output-layer error, {@code (a_3 - y) (.) sigmoid'(z_3)}, one column per example. */
    static double[][] delta3(double[][] inputs, double[][] y) {
        return matrix_hadamard(
                matrix_sum(a_3(inputs), y, -1),
                sigmoid_prime(z_3(inputs)));
    }

    /** Hidden-layer error, {@code (w3^T * delta3) (.) sigmoid'(z_2)}, one column per example. */
    static double[][] delta2(double[][] inputs, double[][] y) {
        return matrix_hadamard(
                matrix_product(transpose_matrix(w3), delta3(inputs, y)),
                sigmoid_prime(z_2(inputs)));
    }

    /** Gradient for {@link #w3}, averaged over the examples. */
    static double[][] dJ_w3(double[][] inputs, double[][] y) {
        return batch_mean_outer(delta3(inputs, y), a_2(inputs));
    }

    /**
     * Gradient for {@link #w2}, averaged over the examples.
     *
     * <p>Entry {@code [i][j]} pairs hidden neuron {@code i}'s error with input
     * {@code j}. The earlier version read {@code delta2[j][k]}, so each row used
     * the wrong neuron's error.
     */
    static double[][] dJ_w2(double[][] inputs, double[][] y) {
        return batch_mean_outer(delta2(inputs, y), inputs);
    }

    /** Gradient for {@link #b3}, averaged over the examples. */
    static double[] dJ_b3(double[][] inputs, double[][] y) {
        return row_means(delta3(inputs, y));
    }

    /** Gradient for {@link #b2}, averaged over the examples. */
    static double[] dJ_b2(double[][] inputs, double[][] y) {
        return row_means(delta2(inputs, y));
    }

    /** Returns {@code g[i][j] = mean over k of delta[i][k] * activations[j][k]}. */
    private static double[][] batch_mean_outer(double[][] delta, double[][] activations) {
        int samples = activations[0].length;
        double[][] grad = new double[delta.length][activations.length];
        for (int i = 0; i < delta.length; i++) {
            for (int j = 0; j < activations.length; j++) {
                double sum = 0;
                for (int k = 0; k < samples; k++) {
                    sum += activations[j][k] * delta[i][k];
                }
                grad[i][j] = sum / samples;
            }
        }
        return grad;
    }

    /** Returns the mean of each row of {@code m}. */
    private static double[] row_means(double[][] m) {
        double[] means = new double[m.length];
        for (int i = 0; i < m.length; i++) {
            double sum = 0;
            for (int k = 0; k < m[0].length; k++) {
                sum += m[i][k];
            }
            means[i] = sum / m[0].length;
        }
        return means;
    }

    // ----- Math -----

    /**
     * Matrix product {@code a * b}.
     *
     * @return a new matrix, or {@code null} if the inner dimensions differ
     */
    static double[][] matrix_product(double[][] a, double[][] b) {
        int inner = a[0].length;
        if (inner != b.length) {
            return null;
        }
        int rows = a.length;
        int cols = b[0].length;
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                for (int k = 0; k < inner; k++) {
                    result[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return result;
    }

    /**
     * Element-wise {@code a + is_sum * b}.
     *
     * @param is_sum scale applied to {@code b} (1 adds, -1 subtracts)
     * @return a new matrix, or {@code null} if the shapes differ
     */
    static double[][] matrix_sum(double[][] a, double[][] b, double is_sum) {
        int rows = a.length;
        int cols = a[0].length;
        if (cols != b[0].length || rows != b.length) {
            return null;
        }
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = a[i][j] + b[i][j] * is_sum;
            }
        }
        return result;
    }

    /**
     * Element-wise {@code a + is_sum * b} for vectors.
     *
     * @return a new vector, or {@code null} if the lengths differ
     */
    static double[] vect_sum(double[] a, double[] b, double is_sum) {
        if (a.length != b.length) {
            return null;
        }
        double[] result = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            result[i] = a[i] + b[i] * is_sum;
        }
        return result;
    }

    /**
     * Adds {@code is_sum * b} to every column of {@code a}.
     *
     * @return a new matrix, or {@code null} if {@code b.length} differs from the row count of {@code a}
     */
    static double[][] matrix_sum_vect(double[][] a, double[] b, double is_sum) {
        int cols = a[0].length;
        int rows = a.length;
        if (rows != b.length) {
            return null;
        }
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = a[i][j] + b[i] * is_sum;
            }
        }
        return result;
    }

    /**
     * Element-wise (Hadamard) product.
     *
     * @return a new matrix, or {@code null} if the shapes differ
     */
    static double[][] matrix_hadamard(double[][] a, double[][] b) {
        int rows = a.length;
        int cols = a[0].length;
        if (cols != b[0].length || rows != b.length) {
            return null;
        }
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = a[i][j] * b[i][j];
            }
        }
        return result;
    }

    /** Returns a new matrix with every element of {@code a} multiplied by {@code scalar}. */
    static double[][] matrix_x_scalar(double[][] a, double scalar) {
        int rows = a.length;
        int cols = a[0].length;
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = a[i][j] * scalar;
            }
        }
        return result;
    }

    /** Returns the transpose of {@code m} as a new matrix. */
    static double[][] transpose_matrix(double[][] m) {
        double[][] result = new double[m[0].length][m.length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                result[j][i] = m[i][j];
            }
        }
        return result;
    }

    /** Logistic function {@code 1 / (1 + e^-z)}. */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Applies {@link #sigmoid(double)} to every element.
     *
     * @return a new matrix; {@code z} is left untouched
     */
    static double[][] sigmoid(double[][] z) {
        double[][] result = new double[z.length][];
        for (int i = 0; i < z.length; i++) {
            result[i] = new double[z[i].length];
            for (int j = 0; j < z[i].length; j++) {
                result[i][j] = sigmoid(z[i][j]);
            }
        }
        return result;
    }

    /** Derivative of the logistic function, {@code s(z) * (1 - s(z))}. */
    static double sigmoid_prime(double z) {
        double s = sigmoid(z);
        return s * (1 - s);
    }

    /**
     * Applies {@link #sigmoid_prime(double)} to every element.
     *
     * @return a new matrix; {@code z} is left untouched
     */
    static double[][] sigmoid_prime(double[][] z) {
        double[][] result = new double[z.length][];
        for (int i = 0; i < z.length; i++) {
            result[i] = new double[z[i].length];
            for (int j = 0; j < z[i].length; j++) {
                result[i][j] = sigmoid_prime(z[i][j]);
            }
        }
        return result;
    }
    // ----- end math -----
}

我很确定错误隐藏在 dJ_w3、dJ_w2 函数中,可能是在对所有梯度求平均的 k 循环里,但我找不到它。你能帮帮我吗?

1 个答案:

答案 0 :(得分:0)

发现问题,我只需要将训练迭代次数增加到50000次。