How to write logistic regression with Scala Breeze and LBFGS?

Date: 2016-08-13 17:27:10

Tags: scala logistic-regression scala-breeze

I am writing logistic regression code in Scala. I am trying to use Scala Breeze, but I get an error when I run my code.

Here is my code. Basically, I tried to write the same thing as here:

import breeze.linalg._
import breeze.optimize._

def lbfgsSolve(features: CSCMatrix[Double], outputs: SparseVector[Double], lambda: Double = 0.0) = {
  val obj = new DiffFunction[SparseVector[Double]] {
    override def calculate(weights: SparseVector[Double]): (Double, SparseVector[Double]) = {

      // numerically stable logistic function 1 / (1 + exp(-x))
      def sigmoid(w: SparseVector[Double]): SparseVector[Double] = {
        w.map( x => if (x>0) (1.0 / (1.0 + Math.exp(-x))) else (Math.exp(x) / (1.0 + Math.exp(x))))
      }

      val m = features.rows.toDouble
      val z = features * weights      // z_i = x_i . w
      val yz = outputs :* z           // element-wise y_i * z_i
      val theta = weights.copy        // weights with the bias (index 0) set to zero
      theta(0) = 0.0

      // numerically stable log(1 + exp(-y_i * z_i))
      val out = yz.map { x => if (x > 0) Math.log(Math.exp(-x) + 1.0) else (-x + Math.log(Math.exp(x) + 1.0)) }
      // L2 penalty over all weights, including the bias
      val loss = (sum(out) / m) + (0.5 * lambda * (weights.t * weights))

      val zz = sigmoid(yz)
      val z0 = (zz - 1.0) :* outputs
      // L2 gradient term uses theta, so the bias gets no penalty here
      val gradient = ((features.t * z0) / m) + (theta * lambda)

      (loss, gradient)
    }
  }
  val initWeights = SparseVector(Array.fill(features.cols)(1.0))
  new LBFGS[SparseVector[Double]](tolerance = 0.01).minimize(obj, initWeights)
}
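For reference, assuming the labels y_i are coded as -1/+1 and w_0 is the bias weight (θ denotes w with θ_0 set to 0, as in the code), a consistent regularized logistic-regression objective and its gradient are:

$$
f(w) = \frac{1}{m}\sum_{i=1}^{m}\log\left(1 + e^{-y_i x_i^\top w}\right) + \frac{\lambda}{2}\,\theta^\top\theta,
\qquad
\nabla f(w) = \frac{1}{m}\,X^\top\left[\left(\sigma(y \circ Xw) - \mathbf{1}\right)\circ y\right] + \lambda\,\theta,
$$

where σ(t) = 1/(1 + e^{-t}) is applied element-wise and ∘ is the element-wise product. The gradient line of the code matches ∇f term by term, but the loss line penalizes (λ/2) wᵀw, whose gradient would be λw rather than λθ; the two differ by λw_0 in the bias coordinate.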

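I am using this dataset to test my code. "admit" is the output and the other columns are the features. I normalized the features with feature scaling (en.wikipedia.org/wiki/Feature_scaling) and added a constant column as the first column, so my feature matrix looks like this: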

scala> features(0 until 4, 0 until 4)
res18: breeze.linalg.SliceMatrix[Int,Int,Double] =
1.0  0.27586206896551724  0.7758620689655172  0.6666666666666666
1.0  0.7586206896551724   0.8103448275862069  0.6666666666666666
1.0  1.0                  1.0                 0.0
1.0  0.7241379310344828   0.5344827586206897  1.0
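The preprocessing described above can be sketched in plain Scala. This is an illustration under assumptions, not the asker's actual code: the helper names are hypothetical, the input is a dense Array[Array[Double]] instead of a Breeze CSCMatrix, the scaling is assumed to be min-max rescaling (consistent with the values in [0, 1] shown above), and mapping a 0/1 "admit" column to -1/+1 assumes that is how the dataset codes it. The -1/+1 coding matters because the loss log(1 + exp(-y z)) used in the question is written for labels in {-1, +1}.

```scala
object Preprocess {
  // Min-max scaling of each column to [0, 1].
  // A constant column is mapped to 0.0 to avoid dividing by zero (an assumed convention).
  def minMaxScale(rows: Array[Array[Double]]): Array[Array[Double]] = {
    val nCols = rows.head.length
    val mins = Array.tabulate(nCols)(j => rows.map(_(j)).min)
    val maxs = Array.tabulate(nCols)(j => rows.map(_(j)).max)
    rows.map { r =>
      Array.tabulate(nCols) { j =>
        val range = maxs(j) - mins(j)
        if (range == 0.0) 0.0 else (r(j) - mins(j)) / range
      }
    }
  }

  // Prepend a constant 1.0 to every row so that weight 0 acts as the intercept.
  def addBias(rows: Array[Array[Double]]): Array[Array[Double]] =
    rows.map(r => 1.0 +: r)

  // Remap 0/1 labels to -1/+1 for the loss log(1 + exp(-y * z)).
  def toPlusMinusOne(labels: Array[Double]): Array[Double] =
    labels.map(y => if (y > 0.5) 1.0 else -1.0)
}
```

Scaling must happen before the bias column is added; otherwise the constant column would be mapped to 0.0.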

When I pass these features and outputs to my function, it reports an error:

scala> val answer = lbfgsSolve(features, outputs, 0.05)
[run-main-0] INFO breeze.optimize.LBFGS - Step Size: 14.07
[run-main-0] INFO breeze.optimize.LBFGS - Val and Grad Norm: 0.566596 (rel: 0.0459) 0.0357693
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.1 fval: 0.5684918452517186 rhs: 0.5665961695023995 cdd: -0.024619310343914663
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.010859375861321754 fval: 0.5667857913421284 rhs: 0.5665963991719202 cdd: -0.025639964206204915
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.0012166564062106571 fval: 0.5666174432637753 rhs: 0.5665964240162465 cdd: -0.025750856078277946
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.36780897002546E-4 fval: 0.5665987873427688 rhs: 0.5665964267985301 cdd: -0.025763280624522322
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.538334048641963E-5 fval: 0.5665966925628935 rhs: 0.5665964271113091 cdd: -0.025764677442936045
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.730193655418023E-6 fval: 0.5665964570019302 rhs: 0.5665964271464863 cdd: -0.025764834539024242
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.9459913173710706E-7 fval: 0.5665964305083527 rhs: 0.5665964271504427 cdd: -0.025764852207922004
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 2.1887054476776026E-8 fval: 0.5665964275285602 rhs: 0.5665964271508876 cdd: -0.02576485419518635
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 2.4616922533245922E-9 fval: 0.5665964271934153 rhs: 0.5665964271509377 cdd: -0.025764854418698992
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 2.7687332010108064E-10 fval: 0.5665964271557209 rhs: 0.5665964271509434 cdd: -0.025764854443838008
[run-main-0] ERROR breeze.optimize.LBFGS - Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.10402596383408635 fval: 0.5666979644593371 rhs: 0.5665964138414288 cdd: -0.0012757889755073697
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.011003739451576303 fval: 0.5666071180772083 rhs: 0.5665964257430798 cdd: -0.0012790548704777202
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.001166003720069149 fval: 0.5665975594520045 rhs: 0.5665964270017607 cdd: -0.001279400679810798
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.2357771173880845E-4 fval: 0.5665965471504739 rhs: 0.566596427135133 cdd: -0.0012794373271490386
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.309751490251532E-5 fval: 0.5665964398691514 rhs: 0.5665964271492683 cdd: -0.0012794412112233761
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.388156908600053E-6 fval: 0.5665964284988991 rhs: 0.5665964271507664 cdd: -0.0012794416228816064
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.4712560321674266E-7 fval: 0.5665964272938085 rhs: 0.5665964271509252 cdd: -0.0012794416665117385
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.5593311038744415E-8 fval: 0.5665964271660856 rhs: 0.566596427150942 cdd: -0.0012794416711359338
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.6526860200416176E-9 fval: 0.5665964271525491 rhs: 0.5665964271509438 cdd: -0.0012794416716260355
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.7514997931281254E-10 fval: 0.5665964271511141 rhs: 0.566596427150944 cdd: -0.00127944167167798
[run-main-0] ERROR breeze.optimize.LBFGS - Failure again! Giving up and returning. Maybe the objective is just poorly behaved?
[run-main-0] INFO breeze.optimize.LBFGS - Converged because line search failed!
answer: breeze.linalg.SparseVector[Double] = SparseVector((0,1.281051822587718), (1,0.47902016540238035), (2,0.4807986641770212), (3,0.3835424930764545))
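In the log, each trial step t is reported with fval (the objective at the trial point) and rhs, and a step can only be accepted when fval <= rhs (the sufficient-decrease, or Armijo, condition). The rhs values appear consistent with rhs = f(x) + c1 * t * (∇f(x) · d) with c1 = 1e-4. Here fval stays above rhs even as t shrinks to about 1e-10, which suggests that the objective does not actually decrease along the direction the supplied gradient calls downhill; that is the typical symptom of a gradient that does not match the function. The 1-D sketch below (plain Scala with hypothetical helper names, not Breeze's StrongWolfeLineSearch, which additionally checks a curvature condition) shows backtracking on f(x) = x² succeeding with the correct derivative and failing for every step when the derivative has the wrong sign.

```scala
object ArmijoDemo {
  // Sufficient-decrease (Armijo) test: accept step t along direction d only if
  //   f(x + t * d) <= f(x) + c1 * t * (g * d)
  // where g is the derivative the caller *claims* at x.
  def armijoHolds(f: Double => Double, x: Double, g: Double, d: Double, t: Double, c1: Double = 1e-4): Boolean =
    f(x + t * d) <= f(x) + c1 * t * g * d

  // Try t = 1, 1/2, 1/4, ... and return the first accepted step, if any.
  def backtrack(f: Double => Double, x: Double, g: Double, maxHalvings: Int = 40): Option[Double] = {
    val d = -g // steepest-descent direction built from the claimed derivative
    Iterator.iterate(1.0)(_ / 2).take(maxHalvings + 1).find(t => armijoHolds(f, x, g, d, t))
  }
}
```

With f(x) = x² at x = 1, the correct derivative 2 gives an accepted step t = 0.5, while the wrong derivative -2 points uphill, so no step size satisfies the condition; the search keeps shrinking t, much like the log above.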

How can I fix this and make it work? I would like to keep using CSCMatrix and SparseVector wherever possible.

I am using Scala 2.11.8, Breeze 0.12, slf4j-simple 1.7.6 and openjdk-1.8.0.

1 answer:

Answer 0 (score: 0)

I haven't looked closely, but your gradient is probably wrong. Try using the GradientTester class, which will print diagnostic information.
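A gradient check compares the analytic gradient with central finite differences of the loss. Breeze provides GradientTester in breeze.optimize for this; the sketch below is a hand-written plain-Scala equivalent (hypothetical names loss, gradient and maxGradientError, dense arrays instead of CSCMatrix/SparseVector) so that it runs without Breeze. It mirrors the question's objective and exposes one concrete mismatch: the loss penalizes all weights (weights.t * weights), while the gradient adds lambda * theta, which skips the bias. With penalizeBias = true the check reports an error of lambda * w(0) in coordinate 0; with the penalty computed on theta it passes. Assuming that is the only problem, the corresponding change in the Breeze code is to compute the penalty as (theta.t * theta) in the loss line, or, if the bias should be regularized too, to use weights * lambda in the gradient line.

```scala
object GradientCheck {
  // Logistic loss for labels y in {-1, +1} plus an L2 penalty; column 0 of x is the bias.
  // penalizeBias = true reproduces the question's loss (0.5 * lambda * w.w);
  // penalizeBias = false penalizes theta (w with w(0) = 0), matching the gradient below.
  def loss(x: Array[Array[Double]], y: Array[Double], w: Array[Double],
           lambda: Double, penalizeBias: Boolean): Double = {
    val m = x.length
    val data = x.indices.map { i =>
      val yz = y(i) * dot(x(i), w)
      // numerically stable log(1 + exp(-yz))
      if (yz > 0) math.log1p(math.exp(-yz)) else -yz + math.log1p(math.exp(yz))
    }.sum / m
    val start = if (penalizeBias) 0 else 1
    val penalty = (start until w.length).map(j => w(j) * w(j)).sum
    data + 0.5 * lambda * penalty
  }

  // Analytic gradient as written in the question: X^T[(sigmoid(y z) - 1) * y] / m + lambda * theta.
  def gradient(x: Array[Array[Double]], y: Array[Double], w: Array[Double], lambda: Double): Array[Double] = {
    val m = x.length
    val g = Array.fill(w.length)(0.0)
    for (i <- x.indices) {
      val coef = (sigmoid(y(i) * dot(x(i), w)) - 1.0) * y(i) / m
      for (j <- w.indices) g(j) += coef * x(i)(j)
    }
    for (j <- 1 until w.length) g(j) += lambda * w(j) // theta(0) = 0: bias not penalized
    g
  }

  // Central finite differences; returns the largest |numeric - analytic| over all coordinates.
  def maxGradientError(f: Array[Double] => Double, grad: Array[Double], w: Array[Double],
                       eps: Double = 1e-6): Double =
    w.indices.map { j =>
      val plus = w.clone(); plus(j) += eps
      val minus = w.clone(); minus(j) -= eps
      math.abs((f(plus) - f(minus)) / (2 * eps) - grad(j))
    }.max

  private def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  private def sigmoid(z: Double): Double =
    if (z > 0) 1.0 / (1.0 + math.exp(-z)) else math.exp(z) / (1.0 + math.exp(z))
}
```

Note that the loss form log(1 + exp(-y z)) also assumes labels coded as -1/+1; with 0/1 labels the rows with y = 0 contribute a constant log 2 and no gradient, so the fit would be wrong even with a consistent gradient.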