Implementing softmax multiclass logistic regression by hand: is this weight update formula correct?

Time: 2018-09-03 13:00:29

Tags: java machine-learning classification logistic-regression gradient-descent

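For a project, I am trying to write multiclass logistic regression in Java. I have been following this tutorial to try to understand it, but I am having trouble coding the training of the model, especially the weight update.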

[Image: weight update formula (not reproduced)]

I am confused about the dimensions of each component, such as the sigma sum (is it multidimensional?) and the weights.
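For reference, a sketch of the dimensions involved. It assumes the formula in question is the standard batch gradient-descent update for softmax regression on the mean cross-entropy loss (an assumption, since the tutorial's formula is not reproduced here). The symbols D = 8 features, K = 10 classes, m training instances, and learning rate \eta are notation introduced for this sketch; 8 and 10 are taken from the array sizes in the code.

```latex
% Shapes (assumed notation): x_i \in \mathbb{R}^{D}, \; y_i \in \mathbb{R}^{K} \text{ (one-hot)}, \; W \in \mathbb{R}^{D \times K}
p_i = \operatorname{softmax}\!\left(W^{\top} x_i\right) \in \mathbb{R}^{K}

J(W) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{k=1}^{K} y_{ik} \log p_{ik}

\nabla_W J = -\frac{1}{m} \sum_{i=1}^{m} x_i \,(y_i - p_i)^{\top} \in \mathbb{R}^{D \times K}

W \leftarrow W - \eta \, \nabla_W J
\quad\Longleftrightarrow\quad
W_{jk} \leftarrow W_{jk} + \frac{\eta}{m} \sum_{i=1}^{m} x_{ij}\,(y_{ik} - p_{ik})
```

Under these assumptions the sigma sum is a D x K matrix with the same shape as the weights: each instance contributes the outer product of its feature vector with its difference vector, not a dot product.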

Here is the code so far (only the most relevant bits are included):

/** the weight to learn */
private static double[][] weights;
static double sum[][] = new double[8][10];

public static class Instance {
    public int label;
    public double[] x;

    public Instance(int label, double[] x) {
        this.label = label;
        this.x = x;
    } //each instance contains a label int and a double array, x, containing the data itself
}

public void train(List<Instance> instances) {
    for (int n = 0; n < ITERATIONS; n++) {
        for (int j = 0; j < 8; j++) { //resetting the sum as 0 for each iteration
            sum[j] = 0;
        }
        for (int i = 0; i < instances.size(); i++) { //sets up the sigma function
            double[] x = instances.get(i).x;
            double[] predicted = classify(x);
            double[] label = onehotencoder(instances.get(i).label);

            double[] difference = new double[predicted.length];
            for (int a = 0; a < difference.length; a++) {
                difference[a] = label[a] - predicted[a];
            }

            double[] result = {0,0,0,0,0,0,0,0};
            for (int p = 0; p < result.length; p++) {
                for (int q = 0; q < result.length; q++) {
                    result[p] += x[p] * difference[q]; //dot product multiplying the difference by x (is this right??)
                }
            }

            for (int j = 0; j < 8; j++) { //doing the sigma summation
                sum[j] += result[j];
            }
        }

        for (int k = 0; k < 10; k++) { //updating weights with formula
            for (int j = 0; j < 8; j++) {
                weights[j][k] = weights[j][k] - rate * ((-1/instances.size()) * sum[j]);
            }
        }
    }
}

public double[] classify(double[] x) {
    double[] logit = new double[10];
    for (int j=0; j < 10; j++) {
        for (int i=0; i < 8;i++)  {
            logit[j] = weights[i][j] * x[i];
        }
    }
    double[] result = softmax(logit);
    return result;
}

softmax, standardisation functions etc. left out
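
For completeness, a minimal sketch of what the omitted helpers might look like. The method names mirror the calls in the code above, but the bodies, the class name SoftmaxHelpers, and the NUM_CLASSES constant are assumptions rather than the original code. The softmax subtracts the largest logit before exponentiating, a common way to avoid overflow.

```java
/** Hypothetical helpers: names mirror the calls in the question, bodies are assumptions. */
class SoftmaxHelpers {
    static final int NUM_CLASSES = 10; // assumed from the hard-coded 10 in the question

    /** Numerically stable softmax: shift by the max logit, exponentiate, normalise. */
    static double[] softmax(double[] logit) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logit) {
            max = Math.max(max, v);
        }
        double[] out = new double[logit.length];
        double total = 0.0;
        for (int k = 0; k < logit.length; k++) {
            out[k] = Math.exp(logit[k] - max);
            total += out[k];
        }
        for (int k = 0; k < out.length; k++) {
            out[k] /= total;
        }
        return out;
    }

    /** One-hot vector of length NUM_CLASSES; assumes labels are integers in 0..NUM_CLASSES-1. */
    static double[] onehotencoder(int label) {
        double[] y = new double[NUM_CLASSES];
        y[label] = 1.0;
        return y;
    }
}
```

Shifting the logits by a constant leaves the softmax output unchanged, so the result is the same as the unshifted definition whenever the latter does not overflow.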

I am almost certain that the sum and the weight update are wrong, but I don't know what I'm doing! Any help would be greatly appreciated.
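
A minimal sketch of how the training loop could be written, assuming the intended update is the standard batch gradient step for softmax regression without a bias term. The class name SoftmaxRegressionSketch, the FEATURES and CLASSES constants, and the train(instances, rate, iterations) signature are assumptions made for illustration. Compared with the code above, the sum is kept as an 8 x 10 matrix built from outer products, the logits are accumulated with +=, and the division by the number of instances is done in floating point (in Java, -1/instances.size() is integer division and evaluates to 0 for more than one instance).

```java
import java.util.List;

/** Sketch only: class name, constants, and method signatures are assumptions. */
class SoftmaxRegressionSketch {
    static final int FEATURES = 8;  // matches the hard-coded 8 in the question
    static final int CLASSES = 10;  // matches the hard-coded 10 in the question

    /** weights[j][k] is the weight of feature j for class k. */
    final double[][] weights = new double[FEATURES][CLASSES];

    static class Instance {
        final int label;
        final double[] x;

        Instance(int label, double[] x) {
            this.label = label;
            this.x = x;
        }
    }

    /** Class probabilities: softmax over K logits, each a dot product over the features. */
    double[] classify(double[] x) {
        double[] logit = new double[CLASSES];
        for (int k = 0; k < CLASSES; k++) {
            for (int j = 0; j < FEATURES; j++) {
                logit[k] += weights[j][k] * x[j]; // accumulate with +=, not =
            }
        }
        return softmax(logit);
    }

    /** Batch gradient descent on the mean cross-entropy loss (no bias, no regularisation). */
    void train(List<Instance> instances, double rate, int iterations) {
        int m = instances.size();
        for (int n = 0; n < iterations; n++) {
            double[][] sum = new double[FEATURES][CLASSES]; // fresh zeros each iteration
            for (Instance inst : instances) {
                double[] predicted = classify(inst.x);
                for (int k = 0; k < CLASSES; k++) {
                    double target = (inst.label == k) ? 1.0 : 0.0; // one-hot label entry
                    double difference = target - predicted[k];
                    for (int j = 0; j < FEATURES; j++) {
                        sum[j][k] += inst.x[j] * difference; // outer product of x and difference
                    }
                }
            }
            for (int j = 0; j < FEATURES; j++) {
                for (int k = 0; k < CLASSES; k++) {
                    // equivalent to w - rate * (-1.0 / m) * sum, with floating-point division
                    weights[j][k] += rate * sum[j][k] / m;
                }
            }
        }
    }

    /** Numerically stable softmax: shift by the max logit before exponentiating. */
    static double[] softmax(double[] logit) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logit) {
            max = Math.max(max, v);
        }
        double[] out = new double[logit.length];
        double total = 0.0;
        for (int k = 0; k < logit.length; k++) {
            out[k] = Math.exp(logit[k] - max);
            total += out[k];
        }
        for (int k = 0; k < out.length; k++) {
            out[k] /= total;
        }
        return out;
    }
}
```

With this layout each weights[j][k] is updated from its own sum[j][k], so the 10 classes no longer share a single per-feature sum.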

0 answers:

No answers