为什么我的逻辑回归实现如此缓慢?

时间:2014-04-21 16:55:11

标签: r statistics regression logistic-regression

以下是R(theoretical details here)中批量梯度下降算法的实现:

logreg = function(y, x) {
    x = as.matrix(x)
    x = apply(x, 2, scale)
    x = cbind(1, x)
    m = nrow(x)
    n = ncol(x)
    alpha = 2/m

    # b = matrix(rnorm(n))
    # b = matrix(summary(lm(y~x))$coef[, 1])
    b = matrix(rep(0, n))
    h = 1 / (1 + exp(-x %*% b))

    J = -(t(y) %*% log(h) + t(1-y) %*% log(1 -h))
    derivJ = t(x) %*% (h-y)


    niter = 0
    while(1) {
        niter = niter + 1
        newb = b - alpha * derivJ
        h = 1 / (1 + exp(-x %*% newb))
        newJ = -(t(y) %*% log(h) + t(0-y) %*% log(1 -h))
        while((newJ - J) >= 0) {
            print("inner while...")
            # step adjust
            alpha = alpha / 1.15
            newb = b - alpha * derivJ
            h = 1 / (1 + exp(-x %*% newb))
            newJ = -(t(y) %*% log(h) + t(1-y) %*% log(1 -h))
        }
        if(max(abs(b - newb)) < 0.001) {
            break
        }
        b = newb
        J = newJ
        derivJ = t(x) %*% (h-y)
    }
    b
    v = exp(-x %*% b)
    h = 1 / (1 + v)
    w = h^2 * v
    # # hessian matrix of cost function
    hess = t(x) %*% diag(as.vector(w)) %*% x
    seMat = sqrt(diag(solve(hess)))
    zscore = b / seMat
    cbind(b, zscore)
}

nr = 5000
nc = 3
# set.seed(17)
x = matrix(rnorm(nr*nc, 0, 999), nr)
x = apply(x, 2, scale)
# y = matrix(sample(0:1, nr, repl=T), nr)
h = 1/(1 + exp(-x %*% rnorm(nc)))
y = round(h)
y[1:round(nr/2)] = sample(0:1, round(nr/2), repl=T)


testglm = function() {
    for(i in 1:20) {
        res = summary(glm(y~x, family=binomial))$coef
    }
    print(res)
}

testlogreg = function() {
    for(i in 1:20) {
        res = logreg(y, x)
    }
    print(res)
}

print(system.time(testlogreg()))
print(system.time(testglm()))

算法给出了正确的结果,但速度慢了十倍。

print(system.time(testlogreg()))

           [,1]      [,2]
[1,] -0.0358877  -1.16332
[2,]  0.1904964   6.09873
[3,] -0.1428953  -4.62629
[4,] -0.9151143 -25.33478
   user  system elapsed 
  4.013   1.037   5.062 
#////////////////////////////////////////////////////
print(system.time(testglm()))
              Estimate Std. Error   z value     Pr(>|z|)
(Intercept) -0.0360447  0.0308617  -1.16794  2.42829e-01
x1           0.1912254  0.0312500   6.11922  9.40373e-10
x2          -0.1432585  0.0309001  -4.63618  3.54907e-06
x3          -0.9178177  0.0361598 -25.38226 3.95964e-142
   user  system elapsed 
  0.482   0.040   0.522 

但如果我不计算标准误差和z值,那么它比glm快一点:

#////////////////////////////////////////////////////
print(system.time(testlogreg()))
           [,1]
[1,] -0.0396199
[2,]  0.2281502
[3,] -0.3941912
[4,]  0.8456839
   user  system elapsed 
  0.404   0.001   0.405 
#////////////////////////////////////////////////////
print(system.time(testglm()))
              Estimate Std. Error   z value     Pr(>|z|)
(Intercept) -0.0397529  0.0309169  -1.28580  1.98514e-01
x1           0.2289063  0.0312998   7.31336  2.60551e-13
x2          -0.3956140  0.0319847 -12.36884  3.85328e-35
x3           0.8483669  0.0353760  23.98144 4.34358e-127
   user  system elapsed 
  0.474   0.000   0.475 

显然,se和z值的计算需要花费很多时间,但glm如何做呢?我该如何改进我的实施?

1 个答案:

答案 0 :(得分:3)

最后想出来了,秘诀在于使用稀疏矩阵(see also this blog post)。

require(Matrix)
logreg = function(y, x) {
    x = as.matrix(x)
    x = apply(x, 2, scale)
    x = cbind(1, x)
    m = nrow(x)
    n = ncol(x)
    alpha = 2/m

    # b = matrix(rnorm(n))
    # b = matrix(summary(lm(y~x))$coef[, 1])
    b = matrix(rep(0, n))
    v = exp(-x %*% b)
    h = 1 / (1 + v)

    J = -(t(y) %*% log(h) + t(1-y) %*% log(1 -h))
    derivJ = t(x) %*% (h-y)


    derivThresh = 0.0000001
    bThresh = 0.001
    while(1) {
        newb = b - alpha * derivJ
        if(max(abs(b - newb)) < bThresh) {
            break
        }
        v = exp(-x %*% newb)
        h = 1 / (1 + v)
        newderivJ = t(x) %*% (h-y)
        if(max(abs(newderivJ - derivJ)) < derivThresh) {
            break
        }
        newJ = -(t(y) %*% log(h) + t(0-y) %*% log(1 -h))
        if(newJ > J) {
            alpha = alpha/2
        }
        b = newb
        J = newJ
        derivJ = newderivJ
    }
    w = h^2 * v
    # # hessian matrix of cost function
    hess = t(x) %*% Diagonal(x = as.vector(w)) %*% x
    seMat = sqrt(diag(solve(hess)))
    zscore = b / seMat
    cbind(b, zscore)
}

nr = 5000
nc = 2
# set.seed(17)
x = matrix(rnorm(nr*nc, 3, 9), nr)
# x = apply(x, 2, scale)
# y = matrix(sample(0:1, nr, repl=T), nr)
h = 1/(1 + exp(-x %*% rnorm(nc)))
y = round(h)
y[1:round(nr/2)] = sample(0:1, round(nr/2), repl=T)


ntests = 13
testglm = function() {
    nr = 5000
    nc = 2
    # set.seed(17)
    x = matrix(rnorm(nr*nc, 3, 9), nr)
    # x = apply(x, 2, scale)
    # y = matrix(sample(0:1, nr, repl=T), nr)
    h = 1/(1 + exp(-x %*% rnorm(nc)))
    y = round(h)
    y[1:round(nr/2)] = sample(0:1, round(nr/2), repl=T)
    for(i in 1:ntests) {
        res = summary(glm(y~x, family=binomial))$coef[, c(1, 3)]
    }
    res
}

testlogreg = function() {
    nr = 5000
    nc = 2
    # set.seed(17)
    x = matrix(rnorm(nr*nc, 3, 9), nr)
    # x = apply(x, 2, scale)
    # y = matrix(sample(0:1, nr, repl=T), nr)
    h = 1/(1 + exp(-x %*% rnorm(nc)))
    y = round(h)
    y[1:round(nr/2)] = sample(0:1, round(nr/2), repl=T)
    for(i in 1:ntests) {
        res = logreg(y, x)
    }
    res
}

print(system.time(testlogreg()))
print(system.time(testglm()))

现在我的实现甚至比R中的glm快一点!

print(system.time(testlogreg()))
          [,1]       [,2]
[1,] -0.022598  -0.739494
[2,] -0.510799 -15.793676
[3,] -0.130177  -4.257121
[4,]  0.578318  17.712392
[5,]  0.299080   9.587985
[6,]  0.244131   7.888600
   user  system elapsed 
  8.954   0.044   9.000 
#////////////////////////////////////////////////////
print(system.time(testglm()))
              Estimate Std. Error    z value    Pr(>|z|)
(Intercept) -0.0226784  0.0305694  -0.741865 4.58169e-01
x1          -0.5129285  0.0323621 -15.849653 1.41358e-56
x2          -0.1305872  0.0305892  -4.269057 1.96301e-05
x3           0.5806001  0.0326719  17.770648 1.19304e-70
x4           0.3002898  0.0312072   9.622454 6.42789e-22
x5           0.2451543  0.0309601   7.918407 2.40573e-15
   user  system elapsed 
 12.218   0.008  12.229