有人可以解释我在Knime工具中实现的Logistic回归

时间:2017-05-26 13:56:08

标签: java logistic-regression knime

我一直在浏览逻辑回归源代码一段时间了。见https://github.com/knime/knime-core/blob/master/org.knime.base/src/org/knime/base/node/mine/regression/logistic/learner3/Learner.java

我对Logistic回归的理解是

1) Inialise a weight vector.
2) for each of the instance (until convergence following steps proceed)
       find response using y = (1/exp(-W_transp*X))   
       find gradient error = y - comp_y         //target - computed y
       append d = d + error*x;    // d is a vector initialised to 0 aftr evry epoch
       weight = weight + learningrate * d  //after every epoch updated

我不明白 irlsRls 方法中的复杂代码究竟是什么。填充了Array2DRowRealMatrix xTwx和xTyu中的两个,但完成的机制不清楚,并且没有提到这些使用的内容。

for (int k = 0; k < tcC - 1; k++) {
                for (int i = 0; i < rC + 1; i++) {
                    for (int ii = i; ii < rC + 1; ii++) {
                        int o = k * (rC + 1);
                        double v = xTwx.getEntry(o + i, o + ii);
                        double w = pi.getEntry(0, k) * (1 - pi.getEntry(0, k));
                        v += x.getEntry(0, i) * w * x.getEntry(0, ii);
                        xTwx.setEntry(o + i, o + ii, v);
                        xTwx.setEntry(o + ii, o + i, v);
                    }
                }
            }
            // fill the rest of xTwx (k != k')
            for (int k = 0; k < tcC - 1; k++) {
                for (int kk = k + 1; kk < tcC - 1; kk++) {
                    for (int i = 0; i < rC + 1; i++) {
                        for (int ii = i; ii < rC + 1; ii++) {
                            int o1 = k * (rC + 1);
                            int o2 = kk * (rC + 1);
                            double v = xTwx.getEntry(o1 + i, o2 + ii);
                            double w = -pi.getEntry(0, k) * pi.getEntry(0, kk);
                            v += x.getEntry(0, i) * w * x.getEntry(0, ii);
                            xTwx.setEntry(o1 + i, o2 + ii, v);
                            xTwx.setEntry(o1 + ii, o2 + i, v);
                            xTwx.setEntry(o2 + ii, o1 + i, v);
                            xTwx.setEntry(o2 + i, o1 + ii, v);
                        }
                    }
                }
            }

            int g = (int)row.getTarget();
            // fill matrix xTyu
            for (int k = 0; k < tcC - 1; k++) {
                for (int i = 0; i < rC + 1; i++) {
                    int o = k * (rC + 1);
                    double v = xTyu.getEntry(o + i, 0);
                    double y = k == g ? 1 : 0;
                    v += (y - pi.getEntry(0, k)) * x.getEntry(0, i);
                    xTyu.setEntry(o + i, 0, v);
                }
}

有人可以指导。 Thanx提前。

0 个答案:

没有答案