部分语音分类问题-神经网络无法学习

时间:2019-05-05 18:09:04

标签: java neural-network backpropagation

我正在编写一个NN,对波兰语的词性进行分类。当我启动神经网络时,我注意到权重不断增加并隐藏,错误(成本)达到最大而不是最小化。

这是我的网络课程:


    import java.io.*;
    import java.lang.Math;
    import java.util.Random;
    import java.text.DecimalFormat;

    public class NeuralNetwork {

        // constructor generates random weights and trains the neural network
        public NeuralNetwork() {
            randomise();
            System.out.println("The network has been initialized with random weights.");
            train();
            System.out.println("Weights have been adjusted and the network is trained.\nProceeding to classification.");
        }

        // format of decimal values to be printed out onto the console
        private DecimalFormat df = new DecimalFormat("#.##");

        // boolean value - if true print system logs
        private boolean printLogs = false;

        // maximal accepted length of a word = 30 letters
        private double[] input = new double[32 * 32 + 20];

        // there are 10 parts of speech in the polish language
        private double[] output = new double[10];

        // abstract decision to have 20 nodes in the hidden layer -> (30 + 10) / 2
        private double[] hiddenLayer = new double[20];

        // 20 nodes in hidden layer x 30 input nodes -> each input times it's wieght is an input for each hidden layer node
        private double[][] weightsForInput = new double [20][32 * 32 + 20];

        // 10 output nodes x 20 hidden layer nodes -> each hidden layer node times it's wieght is an input for each output node
        private double[][] weightsForHiddenLayer = new double[10][20];

        // learning rate
        private double learningRate = 0.1;

    public void generateInputFromWord(String word, int placeInSentence, int sentenceLength) {
        boolean first;
        for (int c = 0; c < word.length(); c++) {
            first = true;
            for (int x = 0; x < 32; x++) {
                if (first && x == convertCharToInt(word.charAt(c)) - 1) {
                    input[32 * c + x] =  1.;
                    first = false;
                }
            }
        }
        input[32 * 32 + (int) Math.round(20. * (double) placeInSentence / sentenceLength) - 1] = 1.;
    }

        // creates output where each node = sigmoid(sum(hiddenLayerNode * wieght))
        private void generateOutput() {
            for(int i = 0; i < output.length; i++) {
                double sum = 0;
                for(int j = 0; j < hiddenLayer.length; j++) {
                    sum += hiddenLayer[j] * weightsForHiddenLayer[i][j];
                }
                output[i] = sigmoid(sum);
            }
        }

        // creates hidden layer where each node = sigmoid(sum(inputNode * weight))
        private void generateHiddenLayer() {
            for(int i = 0; i < hiddenLayer.length; i++) {
                double sum = 0;
                for(int j = 0; j < input.length; j++) {
                    sum += input[j] * weightsForInput[i][j];
                }
                hiddenLayer[i] = sigmoid(sum);
            }
        }

        // returns an integer from 1 to 32 for each letter of the polish alphabet
        private int convertCharToInt(char c){
            switch(c){
                case 'a': return 1;
                case '\u0105': return 2;
                case 'b': return 3;
                case 'c': return 4;
                case '\u0107': return 5;
                case 'd': return 6;
                case 'e': return 7;
                case '\u0119': return 8;
                case 'f': return 9;
                case 'g': return 10;
                case 'h': return 11;
                case 'i': return 12;
                case 'j': return 13;
                case 'k': return 14;
                case 'l': return 15;
                case '\u0142': return 16;
                case 'm': return 17;
                case 'n': return 18;
                case '\u0144': return 19;
                case 'o': return 20;
                case '\u00F3': return 21;
                case 'p': return 22;
                case 'r': return 23;
                case 's': return 24;
                case '\u015B': return 25;
                case 't': return 26;
                case 'u': return 27;
                case 'w': return 28;
                case 'y': return 29;
                case 'z': return 30;
                case '\u017A': return 31;
                case '\u017C': return 32;
                default: return 0;
            }
        }

        // populate wieghts and the bias with random values
        private void randomise() {
            Random random = new Random();

            for (int i = 0; i < weightsForInput.length; i++) {
                for (int j = 0; j < weightsForInput[0].length; j++) {
                    weightsForInput[i][j] = random.nextDouble();
                }
            }

            for (int i = 0; i < weightsForHiddenLayer.length; i++) {
                for (int j = 0; j < weightsForHiddenLayer[0].length; j++) {
                    weightsForHiddenLayer[i][j] = random.nextDouble();
                }
            }
        }

        // sigmoid function 1 / (1 + e ^ (- x)) returns num in range (0, 1)
        private double sigmoid(double input) {
            return 1 / (1 + Math.pow(Math.E, - input));
        }

        // converts PoS String into size[10] array of classification output
        private double[] generateExpectedOutput(String input) {
            double[] result = new double[10];

            switch(input) {
                case "RZECZOWNIK":
                    result[0] = 1;
                    break;
                case "PRZYMIOTNIK":
                    result[1] = 1;
                    break;
                case "LICZEBNIK":
                    result[2] = 1;
                    break;
                case "ZAIMEK":
                    result[3] = 1;
                    break;
                case "CZASOWNIK":
                    result[4] = 1;
                    break;
                case "PRZYSLOWEK":
                    result[5] = 1;
                    break;
                case "PRZYIMEK":
                    result[6] = 1;
                    break;
                case "SPOJNIK":
                    result[7] = 1;
                    break;
                case "WYKRZYKNIK":
                    result[8] = 1;
                    break;
                case "PARTYKULA":
                    result[9] = 1;
                    break;
            }

            return result;
        }

        // convert output array into one of 10 possible parts of speech
        private POS getPOSFromOutput() {
            int node = 0;
            double max = 0;
            POS result = POS.RZECZOWNIK;
            for(int i = 0; i < output.length; i++) {
                if(output[i] > max) {
                    max = output[i];
                    node = i;
                }
            }
            switch(node) {
                case 0:
                    result = POS.RZECZOWNIK;
                    break;
                case 1:
                    result = POS.PRZYMIOTNIK;
                    break;
                case 2:
                    result = POS.LICZEBNIK;
                    break;
                case 3:
                    result = POS.ZAIMEK;
                    break;
                case 4:
                    result = POS.CZASOWNIK;
                    break;
                case 5:
                    result = POS.PRZYSLOWEK;
                    break;
                case 6:
                    result = POS.PRZYIMEK;
                    break;
                case 7:
                    result = POS.SPOJNIK;
                    break;
                case 8:
                    result = POS.WYKRZYKNIK;
                    break;
                case 9:
                    result = POS.PARTYKULA;
                    break;  
            }

            return result;
        }

        // prints a 1D array onto the console
        private void printOneDArray(double[] array, String arrayName) {
            System.out.println(arrayName + ":");
            for(int i = 0; i < array.length; i++) {
                System.out.print(df.format(array[i]) + " ");
            }
            System.out.println("");
        }

        // prints a 2D array onto the console
        private void printTwoDArray(double[][] array, String arrayName) {
            System.out.println(arrayName + ":");
            for(int x = 0; x < array.length; x++) {
                for(int y = 0; y < array[x].length; y++) {
                    System.out.print(df.format(array[x][y]) + " ");
                }
                System.out.println("");
            }
            System.out.println("");
        }

        // calculates average error where each output node error = (output - expectedOutput)^2
        private double calculateError(double[] expectedOutput) {
            double error = 0;
            for(int i = 0; i < output.length; i++) {
                // calculate sum of quadratic difference
                error += Math.pow(expectedOutput[i] - output[i], 2);
            }
            // average the sum
            return error / expectedOutput.length;
        }

        // calculates new weights for the hidden layer based on current weights and calculated error for a specific training element
        private double[][] calculateNewWeightsForHiddenLayer(double[] expectedOutput) {
            double[][] newWeights = new double[weightsForHiddenLayer.length][weightsForHiddenLayer[0].length];
            for(int i = 0; i < expectedOutput.length; i++) {
                for(int x = 0; x < weightsForHiddenLayer.length; x++) {
                    for(int y =0; y < weightsForHiddenLayer[x].length; y++) {
                        // derivate (weight * hiddenLayer) in respect to weight * derivative (sigmoid(x)) in respect to x * derivative (output - expectedOutput)^2 in respect to output
                        newWeights[x][y] += weightsForHiddenLayer[x][y] - learningRate *
                                            hiddenLayer[y] *
                                            (sigmoid(weightsForHiddenLayer[x][y] * hiddenLayer[y]) * (1 - sigmoid(weightsForHiddenLayer[x][y] * hiddenLayer[y]))) *
                                            2 * (expectedOutput[i] - output[i]);
                    }
                }
            }
            newWeights = calculateAverageArray(newWeights, expectedOutput.length);
            return newWeights;
        }

        // calculates new weights for the input layer based on current weights and calculated error for a specific training element
        private double[][] calculateNewWeightsForInput(double[] expectedOutput) {
            double[][] newWeights = new double[weightsForInput.length][weightsForInput[0].length];
            for(int i = 0; i < expectedOutput.length; i++) {
                for(int x = 0; x < weightsForInput.length; x++) {
                    for(int y = 0; y < weightsForInput[x].length; y++) {
                        // derivative function
                        newWeights[x][y] += weightsForInput[x][y] - learningRate *
                                            input[y] *
                                            (sigmoid(weightsForHiddenLayer[i][x] * hiddenLayer[x]) * (1 - sigmoid(weightsForHiddenLayer[i][x] * hiddenLayer[x]))) *
                                            weightsForHiddenLayer[i][x] *
                                            (sigmoid(weightsForInput[x][y] * input[y]) * (1 - sigmoid(weightsForInput[x][y] * input[y]))) *
                                            2 * (expectedOutput[i] - output[i]);
                    }
                }
            }
            newWeights = calculateAverageArray(newWeights, expectedOutput.length);
            return newWeights;
        }

        // adds two arrays of same size together
        private double[][] addArrays(double[][] arrayOne, double arrayTwo[][]) {
            double sumArray[][] = new double[arrayOne.length][arrayOne[0].length];
            for(int x = 0; x < arrayOne.length; x++) {
                for(int y =0; y < arrayOne[x].length; y++) {
                    sumArray[x][y] = arrayOne[x][y] + arrayTwo[x][y];
                }
            }
            return sumArray;
        }

        // calculates the average for a layer by dividing each element by the amount of elements used to create sum
        private double[][] calculateAverageArray(double[][] array, int counter) {
            double[][] averageArray = new double[array.length][array[0].length];
            for(int x = 0; x < array.length; x++) {
                for(int y= 0; y < array[x].length; y++) {
                    averageArray[x][y] = array[x][y] / counter;
                }
            }
            return averageArray;
        }

        // trains the neural network based on a training set in the file training_set.txt
        private void train() {
            try {
                // classification error
                double error = 0;
                double errorSum = 0;
                // count which element is being calculated
                int counter = 0;
                int counterSum = 0;
                // train as long as the error is above 10%
                do{
                    if(printLogs) {
                        printTwoDArray(weightsForInput, "Weights for input");
                        printTwoDArray(weightsForHiddenLayer, "Weights for hidden layer");
                    }
                    // input file with text to be classified
                    File file = new File("training_set.txt"); 
                    BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF8"));
                    // input line of text
                    String line = "";
                    // reseting counter and error sum (for average error calculation) for each new passage through the training set
                    counter = 0;
                    errorSum = 0;
                    // array with temporary weightsForInput
                    double[][] tmpInputWeights = new double[20][32 * 32 + 20];
                    // array for temporary weightsForHiddenLayer
                    double[][] tmpHiddenLayerWeights = new double[10][20];
                    // loop to train each element in the training set
                    while ((line = br.readLine()) != null) {
                        counter++;
                        // [0] Word [1] place in sentence [2] sentence length [3] PoS
                        String[] data = line.split(",");
                        generateInputFromWord(data[0], Integer.parseInt(data[1]), Integer.parseInt(data[2]));
                        generateHiddenLayer();
                        generateOutput();
                        error = calculateError(generateExpectedOutput(data[3]));
                        errorSum += error;
                        if(printLogs) {
                            System.out.println(counter + " training element: " + data[0] + " " + data[1] + " " + data[2] + " " + data[3]);
                            printOneDArray(input, "Input layer");
                            printOneDArray(hiddenLayer, "Hidden layer");
                            printOneDArray(output, "Output layer");
                            printOneDArray(generateExpectedOutput(data[3]), "Expected output");
                            System.out.println("\n" + "Error: " + error + "\n");
                        }
                        tmpHiddenLayerWeights = addArrays(tmpHiddenLayerWeights, calculateNewWeightsForHiddenLayer(generateExpectedOutput(data[3])));
                        tmpInputWeights = addArrays(tmpInputWeights, calculateNewWeightsForInput(generateExpectedOutput(data[3])));
                        for(int i = 0; i < input.length; i++) {
                            input[i] = 0;
                        }
                    }
                    counterSum += counter;
                    br.close();
                    // calculate average weights
                    weightsForHiddenLayer = calculateAverageArray(tmpHiddenLayerWeights, counter);
                    weightsForInput = calculateAverageArray(tmpInputWeights, counter);
                    // calculate average error over all the training set
                    errorSum /= counter;
                    if(printLogs) {
                        System.out.println("\n" + "Average error: " + errorSum + "\n");
                    }
                } while(errorSum > 0.1);
            } catch (Exception e) {
                System.out.println("Error - main: " + e.getMessage());
                e.printStackTrace();
            }
        }

        // classify a word as one of 10 possible parts of speech
        public POS classify(String word, int placeInSentence, int sentenceLength) {
            generateInputFromWord( word, placeInSentence, sentenceLength);
            generateHiddenLayer();
            generateOutput();
            return getPOSFromOutput();
        }

        // get value in output array
        public double[] getOutput() {
            return output;
        }

    }

和我的训练集样本:

wszyscy,1,7,RZECZOWNIK
jesteśmy,2,7,CZASOWNIK
studentami,3,7,RZECZOWNIK
lub,4,7,SPOJNIK
od,5,7,PRZYIMEK
niedawna,6,7,PRZYSLOWEK
absolwentami,7,7,RZECZOWNIK

反向传播算法正确吗? 我根据整个训练集计算平均权重。 新的权重在这里计算:

private double[][] calculateNewWeightsForHiddenLayer(double[] expectedOutput)
private double[][] calculateNewWeightsForInput(double[] expectedOutput)

1 个答案:

答案 0 :(得分:2)

在设计神经网络时,一项非常重要的任务是选择一组合适的功能,这些功能对于神经网络来说很容易使用。

通过将单词的位置编码为与字符本身相同的值,您将使网络无法学习任何内容,因为这会将许多可能不相关的单词映射到相同的编码。

通过为字符使用单个标量,使网络很难区分字符。相反,使用单点编码可能会更好,将每个字符表示为一个大多数为零的向量,并且在对应于该字符的索引处加上一个。

如果要在字符级别上工作并包括单词的位置,请使用其他方式对位置进行编码,例如通过将其编码为一组辅助输入,该输入可以是与单词位置的二进制表示相对应的1和0的向量。

通常,用于自然语言的神经网络使用所谓的单词嵌入,其中每个单词都映射到一个唯一的矢量表示形式,该矢量表示形式取决于单词出现的上下文(例如word2vec,Glove)。

对于语音标记的一部分,当前单词周围的单词也与标签分类有关。即使单词在句子中的位置相同,单词也会根据其周围的单词具有不同的标记。这就是为什么基于神经网络的语音标记的通用方法涉及递归神经网络的原因。