Logistic regression gives incorrect results

Asked: 2012-12-02 20:57:01

Tags: ruby algorithm statistics machine-learning regression

I'm working on a website where I collect the results of chess games people have played. Looking at a player's rating and the difference between their rating and their opponents' ratings, I plot a graph with points representing wins (green), draws (blue), and losses (red).

Using this information, I also implemented a logistic regression algorithm to classify the cutoffs for a win and for a win/draw. With rating and rating difference as my two features, I get a classifier and then plot on the graph where the classifier changes its prediction.

My gradient descent code, cost function, and sigmoid function are below.

  def gradient_descent()
    oldJ = 0    
    newJ = J()
    alpha = 1.0     # Learning rate
    run = 0
    while (run < 100) do
      tmpTheta = Array.new
      for j in 0...numFeatures do
        sum = 0
        for i in 0...m do
          sum += ((h(training_data[:x][i]) - training_data[:y][i][0]) * training_data[:x][i][j])
        end
        tmpTheta[j] = Array.new
        tmpTheta[j][0] = theta[j, 0] - (alpha / m) * sum  # Alpha * partial derivative of J with respect to theta_j
      end
      self.theta = Matrix.rows(tmpTheta)
      oldJ = newJ
      newJ = J()
      run += 1
      if (run == 100 && (oldJ - newJ > 0.001)) then run -= 20 end   # Do 20 more if the error is still going down a fair amount.
      if (oldJ < newJ)
        alpha /= 10
      end
    end
  end

  def J()
    sum = 0
    for i in 0...m
      sum += ((training_data[:y][i][0] * Math.log(h(training_data[:x][i]))) 
          + ((1 - training_data[:y][i][0]) * Math.log(1 - h(training_data[:x][i]))))
    end
    return (-1.0 / m) * sum
  end

  def h(x)
    if (x.class != 'Matrix')    # In case it comes in as a row vector or an array
      x = Matrix.rows([x])      # [x] because if it's a row vector we want [[a, b]] to get an array whose first row is x.
    end
    x = x.transpose   # x is supposed to be a column vector, and theta^ a row vector, so theta^*x is a number.
    return g((theta.transpose * x)[0, 0])  # theta^ * x gives [[z]], so get [0, 0] of that for the number z.
  end

  def g(z)
    tmp = 1.0 / (1.0 + Math.exp(-z))   # Sigmoid function
    if (tmp == 1.0) then tmp = 0.99999 end    # These two things are here because ln(0) DNE, so we don't want to do ln(1 - 1.0) or ln(0.0)
    if (tmp == 0.0) then tmp = 0.00001 end
    return tmp
  end

When I tested it on the data set representing my own chess games, I got reasonable results that I was happy with:

[Graph 12842311: correct result]

For a while, I was happy; every example I tried gave interesting graphs. Then I tried it on a player named Kevin Cao, who has over 250 games to his name and thus a very large training set (well over 1,000 points). The results were clearly incorrect:

[Graph 12905349: incorrect result]

Well, that's not good. So as my first idea I increased the initial learning rate from 1.0 to 100.0. That gives what looks like the correct result for Kevin:

[Graph 12905349: correct result]

Unfortunately, when I then tried it on myself and my small data set, I got the strange phenomenon that one of the predictions is just a flat line at 0:

[Graph 12842311: incorrect result]

I checked theta, and it came out as [[2.3707682771730836], [21.22408286825226], [-19081.906528679192]]. The third training variable (really the second, since x_0 = 1) is the rating difference, so when the difference is even the slightest bit positive, the argument of the sigmoid goes hugely negative and it predicts y = 0. When the difference is the slightest bit negative, it likewise jumps the other way and predicts y = 1.

I dropped the initial learning rate from 100.0 back to 1.0 and decided to try reducing it more slowly: when the cost function increases, instead of dividing it by ten I divided it by two.

Unfortunately, that didn't change my results. Even when I increased the number of gradient descent iterations from 100 to 1000, it still predicted the wrong result.

I'm still a beginner at logistic regression (I just finished the machine learning course on Coursera, and this is my first attempt at implementing anything I learned there), so I've reached the limit of my intuition. If anyone can help me figure out what's going wrong here, what I'm doing wrong, and how to fix it, I would really appreciate it.

Edit: I also tried another data set with about 300 data points and again got a flat green line and a normal blue line. The algorithm is essentially the same for both, just with some different values of y, since I'm doing multi-class classification.

Edit: Since people have asked for it, here are J, alpha, and theta for every gradient descent iteration of the run that produces the flat line:

J: 1.7679949412730092  Alpha: 1.0  Theta: Matrix[[-0.004477611940298508], [0.2835820895522388], [-123.63880597014925]]
J: 0.6873432218114784  Alpha: 0.1  Theta: Matrix[[-0.008057848266678727], [-8.033992854843122], [-118.62571350649955]]
J: 2.7493579020963597  Alpha: 0.1  Theta: Matrix[[0.0035837099422764904], [10.036108977992713], [-114.29679460799208]]
J: 2.5431564907845736  Alpha: 0.01  Theta: Matrix[[0.002061352330336195], [7.255061503962862], [-113.88091708799209]]
J: 2.268221136398013  Alpha: 0.01  Theta: Matrix[[0.0008076454646645536], [4.923257856798684], [-113.43169704202194]]
J: 2.02765281325063  Alpha: 0.01  Theta: Matrix[[-0.00014755931145485107], [3.0843409102315205], [-112.95644762679805]]
J: 1.821451342237053  Alpha: 0.01  Theta: Matrix[[-0.0008639634905593289], [1.6548476959031622], [-112.46627318829059]]
J: 1.8214513720879484  Alpha: 0.01  Theta: Matrix[[-0.0013117163263802246], [0.6758826956046561], [-111.9660989569473]]
J: 1.8214513720879484  Alpha: 0.001  Theta: Matrix[[-0.0013535066248876874], [0.5834935043210742], [-111.91600392423089]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0013952969233951501], [0.49110431303749225], [-111.86590889151448]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0014341021771264934], [0.40365238581361185], [-111.81578997843985]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0014729074308578367], [0.31620045858973145], [-111.76567106536523]]
J: 1.752717488714965  Alpha: 0.001  Theta: Matrix[[-0.0015115010626209136], [0.22904945780472585], [-111.71555130580272]]
J: 1.752717488714965  Alpha: 0.001  Theta: Matrix[[-0.001544336226800018], [0.15110191314800955], [-111.66540851236988]]
J: 1.770809597429665  Alpha: 0.001  Theta: Matrix[[-0.0015771713909791226], [0.07315436849129325], [-111.61526571893704]]
J: 1.7297985323807161  Alpha: 0.0001  Theta: Matrix[[-0.00158045491336022], [0.06535960382896211], [-111.61025143962061]]
J: 1.718350722631126  Alpha: 0.0001  Theta: Matrix[[-0.0015837319880072584], [0.05757622586497872], [-111.60523715385645]]
J: 1.7183505768797593  Alpha: 0.0001  Theta: Matrix[[-0.0015867170175074515], [0.05030859963032436], [-111.60022257604714]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0015897020324328638], [0.04304099913473299], [-111.59520799822326]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0015926870473582369], [0.03577339863921061], [-111.59019342039937]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.00159567206228361], [0.028505798143688237], [-111.58517884257549]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001598657077208983], [0.02123819764816586], [-111.5801642647516]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001601642092134356], [0.013970597152643486], [-111.57514968692772]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001604627107059729], [0.006702996657121109], [-111.57013510910383]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016076121219851022], [-0.0005646038384012671], [-111.56512053127994]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016105971369104752], [-0.007832204333923645], [-111.56010595345606]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016135821518358483], [-0.01509980482944602], [-111.55509137563217]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016165671667612213], [-0.022367405324968396], [-111.55007679780829]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016195521816865944], [-0.02963500582049077], [-111.5450622199844]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016225371966119674], [-0.03690260631601315], [-111.54004764216052]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016255222115373405], [-0.04417020681153553], [-111.53503306433663]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016285072264627136], [-0.05143780730705791], [-111.53001848651274]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016314922443731613], [-0.05870541239661013], [-111.52500390868587]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016344772622834192], [-0.06597301748587016], [-111.519989330859]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016374622664495802], [-0.07324060142296517], [-111.51497475304588]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001640217664533409], [-0.08015482159935092], [-111.50996040483884]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016455906875599943], [-0.0937712290880118], [-111.49993184619791]]
J: 1.994702022407994  Alpha: 0.0001  Theta: Matrix[[-0.0016482771980077554], [-0.10057943119248941], [-111.49491756687851]]
J: 1.9789198631246232  Alpha: 1.0e-05  Theta: Matrix[[-0.0016485458502465615], [-0.10126025363935508], [-111.49441613894419]]
J: 1.948354991984789  Alpha: 1.0e-05  Theta: Matrix[[-0.0016490831547241735], [-0.10262189853308641], [-111.49341328307554]]
J: 1.9331013621188657  Alpha: 1.0e-05  Theta: Matrix[[-0.0016493518069629796], [-0.10330272097995208], [-111.49291185514122]]
J: 1.9178620371528292  Alpha: 1.0e-05  Theta: Matrix[[-0.0016496204592017856], [-0.10398354342681772], [-111.49241042720689]]
J: 1.902623825636303  Alpha: 1.0e-05  Theta: Matrix[[-0.0016498891114405914], [-0.10466436587368326], [-111.49190899927257]]
J: 1.8873858680247269  Alpha: 1.0e-05  Theta: Matrix[[-0.0016501577636793972], [-0.10534518832054848], [-111.49140757133824]]
J: 1.8721478527437034  Alpha: 1.0e-05  Theta: Matrix[[-0.0016504264159182024], [-0.10602601076741257], [-111.49090614340392]]
J: 1.8569098083540256  Alpha: 1.0e-05  Theta: Matrix[[-0.0016506950681570054], [-0.10670683321427255], [-111.4904047154696]]
J: 1.8416717846532462  Alpha: 1.0e-05  Theta: Matrix[[-0.0016509637203958004], [-0.10738765566111781], [-111.48990328753527]]
J: 1.8264337702403803  Alpha: 1.0e-05  Theta: Matrix[[-0.0016512323726345674], [-0.10806847810791036], [-111.48940185960095]]
J: 1.8111957469624462  Alpha: 1.0e-05  Theta: Matrix[[-0.0016515010251717409], [-0.1087493010703349], [-111.48890043166602]]
J: 1.7959577228777213  Alpha: 1.0e-05  Theta: Matrix[[-0.001651769677708553], [-0.10943012403208266], [-111.4883990037311]]
J: 1.7807196990939538  Alpha: 1.0e-05  Theta: Matrix[[-0.0016520383302440706], [-0.11011094699140556], [-111.48789757579618]]
J: 1.7654816767669712  Alpha: 1.0e-05  Theta: Matrix[[-0.0016523069827749494], [-0.11079176994204029], [-111.48739614786128]]
J: 1.7197677244765115  Alpha: 1.0e-05  Theta: Matrix[[-0.0016531129399852717], [-0.11283423807786983], [-111.4858918640573]]
J: 1.7045300185036796  Alpha: 1.0e-05  Theta: Matrix[[-0.0016533815914621833], [-0.11351505905442376], [-111.48539043612449]]
J: 1.689293134633683  Alpha: 1.0e-05  Theta: Matrix[[-0.0016536502402002386], [-0.11419587490110002], [-111.48488900819716]]
J: 1.674059195452273  Alpha: 1.0e-05  Theta: Matrix[[-0.001653918879126327], [-0.1148766723699622], [-111.48438758028945]]
J: 1.6588357959146847  Alpha: 1.0e-05  Theta: Matrix[[-0.0016541874829120791], [-0.11555740402097447], [-111.48388615245203]]
J: 1.6436500186219352  Alpha: 1.0e-05  Theta: Matrix[[-0.0016544559609891405], [-0.1162379002196091], [-111.48338472486603]]
J: 1.6285972611659707  Alpha: 1.0e-05  Theta: Matrix[[-0.001654723991174496], [-0.11691755751707966], [-111.4828832981758]]
J: 1.6139994752963014  Alpha: 1.0e-05  Theta: Matrix[[-0.0016549904481917704], [-0.11759426827073645], [-111.48238187463193]]
J: 1.600799606845299  Alpha: 1.0e-05  Theta: Matrix[[-0.0016552516449943116], [-0.11826112664220582], [-111.48188046160847]]
J: 1.5908244528084288  Alpha: 1.0e-05  Theta: Matrix[[-0.0016554977759847996], [-0.1188997667477244], [-111.48137907871664]]
J: 1.5851960976828814  Alpha: 1.0e-05  Theta: Matrix[[-0.0016557144987826046], [-0.11948332530842007], [-111.4808777546412]]
J: 1.5826817076400923  Alpha: 1.0e-05  Theta: Matrix[[-0.0016558999497352893], [-0.12000831170339445], [-111.48037649310945]]
J: 1.5816354848004566  Alpha: 1.0e-05  Theta: Matrix[[-0.0016560658987327093], [-0.12049677093659837], [-111.4798752705816]]
J: 1.581199878569286  Alpha: 1.0e-05  Theta: Matrix[[-0.0016562224426970157], [-0.12096761454376066], [-111.47937406686383]]
J: 1.5810169018926878  Alpha: 1.0e-05  Theta: Matrix[[-0.0016563748211790893], [-0.12143065620486218], [-111.47887287147701]]
J: 1.5809396242131868  Alpha: 1.0e-05  Theta: Matrix[[-0.0016565254040880424], [-0.1218903347622732], [-111.47837167968135]]
J: 1.5809069017613124  Alpha: 1.0e-05  Theta: Matrix[[-0.0016566752202995195], [-0.12234857730581448], [-111.47787048941908]]
J: 1.5808930296490606  Alpha: 1.0e-05  Theta: Matrix[[-0.001656824710233385], [-0.12280620875454971], [-111.47736929980935]]
J: 1.580887145848097  Alpha: 1.0e-05  Theta: Matrix[[-0.0016569740612930289], [-0.12326358014294572], [-111.47686811047738]]
J: 1.580884649719601  Alpha: 1.0e-05  Theta: Matrix[[-0.0016571233527736234], [-0.12372084005243131], [-111.47636692126457]]
J: 1.5808835906710963  Alpha: 1.0e-05  Theta: Matrix[[-0.0016572726175860411], [-0.12417805026085695], [-111.47586573210509]]
J: 1.5808831413239819  Alpha: 1.0e-05  Theta: Matrix[[-0.00165742186803091], [-0.12463523410670607], [-111.47536454297435]]
.........

And for the run that produces the correct prediction:

J: 4.330234652497978  Alpha: 1.0  Theta: Matrix[[0.12388059701492538], [211.9910447761194], [-111.13731343283582]]
J: 4.330234652497978  Alpha: 0.1  Theta: Matrix[[0.08626965671641812], [152.3222144059701], [-118.07202388059702]]
J: 4.2958677406623815  Alpha: 0.1  Theta: Matrix[[0.048658716417910856], [92.65338403582082], [-125.0067343283582]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.011644779104478219], [33.61767533134318], [-131.44443979104477]]
J: 0.4467735852246924  Alpha: 0.1  Theta: Matrix[[-0.014623104477611202], [-11.126378913433022], [-132.24166105074627]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.01194378805970217], [31.177094038805805], [-126.89243925671643]]
J: 3.0930257965656063  Alpha: 0.01  Theta: Matrix[[0.009436400895523079], [26.892626149850567], [-126.92472924]]
J: 2.7493567080605392  Alpha: 0.01  Theta: Matrix[[0.007257365074627634], [23.13644550388053], [-126.8386038647761]]
J: 2.508788325211366  Alpha: 0.01  Theta: Matrix[[0.005466380895523164], [19.99261048238799], [-126.62851089164178]]
J: 2.405687589704577  Alpha: 0.01  Theta: Matrix[[0.004152999104478391], [17.61296913194023], [-126.28907722179103]]
J: 2.268219942362192  Alpha: 0.01  Theta: Matrix[[0.002959017910448543], [15.415473392238736], [-125.92224111492536]]
J: 2.1307522353180164  Alpha: 0.01  Theta: Matrix[[0.002093389253732125], [13.751072827761122], [-125.48597339134326]]
J: 2.027651529662123  Alpha: 0.01  Theta: Matrix[[0.0014367116417918252], [12.436814710149182], [-125.00961691402983]]
J: 1.9589177059909308  Alpha: 0.01  Theta: Matrix[[0.0009889847761201823], [11.44908667850739], [-124.49911195194028]]
J: 1.8558169406332465  Alpha: 0.01  Theta: Matrix[[0.0006606582089560022], [10.652638055522315], [-123.97004023522386]]
J: 1.8214500586485458  Alpha: 0.01  Theta: Matrix[[0.0004218823880604789], [9.988664770447688], [-123.42914782925371]]
J: 1.8214500884994413  Alpha: 0.01  Theta: Matrix[[0.0002428068653197179], [9.416182220312082], [-122.88082274064425]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00023086931308091184], [9.369775500013574], [-122.82513353589798]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00021893176084210577], [9.323368779715066], [-122.7694443311517]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0002069942086032997], [9.276962059416558], [-122.71375512640543]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00019505665636449364], [9.23055533911805], [-122.65806592165916]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00018311910412568757], [9.184148618819542], [-122.60237671691289]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0001711815518868815], [9.137741898521034], [-122.54668751216661]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00015924399964807544], [9.091335178222526], [-122.49099830742034]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00014730641755852312], [9.04492840598372], [-122.43530910670393]]
J: 1.8677695240029366  Alpha: 0.001  Theta: Matrix[[0.0001353688354689708], [8.998521633744915], [-122.37961990598751]]
J: 1.8462563443835032  Alpha: 0.0001  Theta: Matrix[[0.0001341750742749415], [8.993880951437452], [-122.374050986289]]
J: 1.8247430163841476  Alpha: 0.0001  Theta: Matrix[[0.00013298131308164604], [8.98924026913124], [-122.3684820665904]]
J: 1.803243007740144  Alpha: 0.0001  Theta: Matrix[[0.0001317875528781551], [8.984599588510665], [-122.36291314676808]]
J: 1.7875423426167685  Alpha: 0.0001  Theta: Matrix[[0.00013059512176735966], [8.979961171334951], [-122.35734406080917]]
J: 1.7870839229503594  Alpha: 0.0001  Theta: Matrix[[0.0001296573060241053], [8.97575636413016], [-122.35174314792931]]
J: 1.7870831481868632  Alpha: 0.0001  Theta: Matrix[[0.00012876197468911015], [8.971623907872633], [-122.34613692449842]]
J: 1.7870831468153818  Alpha: 0.0001  Theta: Matrix[[0.00012786672082037553], [8.967491583540149], [-122.34053069138426]]
J: 1.7870831468129538  Alpha: 0.0001  Theta: Matrix[[0.000126971467088789], [8.963359259441226], [-122.33492445825294]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001260762133574453], [8.959226935342718], [-122.3293182251216]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012518095962610202], [8.95509461124421], [-122.32371199199025]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012428570589475874], [8.950962287145702], [-122.3181057588589]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012339045216341546], [8.946829963047193], [-122.31249952572756]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012249519843207218], [8.942697638948685], [-122.30689329259621]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012159994470072888], [8.938565314850177], [-122.30128705946487]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012070469096938559], [8.934432990751668], [-122.29568082633352]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001198094372380423], [8.93030066665316], [-122.29007459320218]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.000118914183506699], [8.926168342554652], [-122.28446836007083]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00011801892977535571], [8.922036018456144], [-122.27886212693949]]
......

Edit: I noticed that the hypothesis always predicts 0.5 on the first iteration, since theta is all zeros. But after that it only ever predicts 1 or 0 (0.00001 or 0.99999 in my code, to avoid taking the log of 0, which doesn't exist). That doesn't seem right to me, it's far too confident, and it may be the key to why this isn't working.

4 answers:

Answer 0 (score: 3)

There are a few things about your implementation that are a little off.

  1. First, the logistic regression objective is usually given as the minimization problem

     lr(x[n], y[n]) = log(1 + exp(-y[n] * dot(w, x[n])))

     where y[n] is either 1 or -1.

     You seem to be using the equivalent maximization formulation

     lr(x[n], y[n]) = -y[n] * log(1 + exp(-dot(w, x[n]))) + (1 - y[n]) * (-dot(w, x[n]) - log(1 + exp(-dot(w, x[n]))))

     where y[n] is 0 or 1 (y[n] = 0 in this formulation corresponds to y[n] = -1 in the first).

     So you should make sure that the labels in your dataset are 0 and 1 (and not 1 and -1), matching the formulation you are optimizing.

  2. Next, the LR objective is usually not divided by m (the size of the dataset). This scaling factor is not correct when you view logistic regression as a probabilistic model.

  3. Finally, your implementation may have some numerical problems (which you try to correct for in your g function). Leon Bottou's sgd code (http://leon.bottou.org/projects/sgd) does a more stable computation of the loss function and its derivative, shown below (in C code; he uses the first LR formulation I mentioned). A rough Ruby adaptation of the same idea is sketched after this list.

    /* logloss(a,y) = log(1+exp(-a*y)) */
    double loss(double a, double y)
    {
      double z = a * y;
      if (z > 18) {
        return exp(-z);
      }
      if (z < -18) {
        return -z;
      }
      return log(1 + exp(-z));
    }
    
    /*  -dloss(a,y)/da */
    double dloss(double a, double y)
    {
      double z = a * y;
      if (z > 18) {
        return y * exp(-z);
      }
      if (z < -18){
        return y;
      }
      return y / (1 + exp(z));
    }
    
  4. You should also consider running a stock l-bfgs routine (I'm not familiar with Ruby implementations), so that you can focus on getting the objective and gradient computations right and not have to worry about things like the learning rate.
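Picking up point 3, here is a minimal Ruby sketch of the same overflow-guarded loss and derivative (my own adaptation for illustration, not code from Bottou's project; it assumes the ±1 label convention and that `a` is already the dot product of the weights and the features):

    module StableLogLoss
      module_function

      # logloss(a, y) = log(1 + exp(-a*y)), with guards so exp never overflows
      def loss(a, y)
        z = a * y
        return Math.exp(-z) if z > 18    # log(1 + e) ~= e when e is tiny
        return -z if z < -18             # log(1 + exp(-z)) ~= -z when exp(-z) is huge
        Math.log(1 + Math.exp(-z))
      end

      # -d loss / d a
      def dloss(a, y)
        z = a * y
        return y * Math.exp(-z) if z > 18
        return y if z < -18
        y / (1 + Math.exp(z))
      end
    end

    # Hypothetical usage with plain arrays for the weights and one example:
    # w = [0.1, -0.2, 0.05]
    # x = [1, 1450, 30]
    # a = w.zip(x).map { |wi, xi| wi * xi }.sum
    # StableLogLoss.loss(a, 1)      # cost for a won game   (label +1)
    # StableLogLoss.dloss(a, -1)    # -dJ/da for a lost game (label -1)

Clamping the sigmoid output, as g does in the question, avoids log(0) but makes the cost insensitive once predictions saturate; working with the log expression directly keeps the cost well-behaved.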

Answer 1 (score: 1)

A few thoughts:

  • I think it would help if you could show the values of J() and alpha for some of the iterations.
  • Are you including a constant (bias) term as a feature? If I remember correctly, if you don't, the (straight) line where h() == 0.5 is forced to pass through the origin (see the sketch after these bullets).
  • Your function J() seems to be returning the negative log-likelihood (so you want to minimize it). However, you decrease the learning rate if (oldJ < newJ), i.e. if J() got bigger, i.e. worse.
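On the bias point, a minimal sketch (my own, with made-up example values; it assumes the question's training_data layout) of prepending a constant x_0 = 1 feature to each example:

    # Hypothetical raw examples: [rating, rating_difference] with 0/1 labels.
    raw_x = [[1450, 30], [1450, -120], [1600, 85]]
    y     = [[1], [0], [1]]

    # Prepend the constant feature x_0 = 1 to every example so that theta[0]
    # acts as an intercept; without it the h(x) == 0.5 boundary is forced
    # through the origin.
    x_with_bias = raw_x.map { |row| [1] + row }
    # => [[1, 1450, 30], [1, 1450, -120], [1, 1600, 85]]

    training_data = { x: x_with_bias, y: y }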

Answer 2 (score: 0)

Floating point numbers:

How about this? Comparing floats for exact equality doesn't make much sense to me.

def g(z)
    tmp = 1.0 / (1.0 + Math.exp(-z))   # Sigmoid function
    if (tmp >= 0.99999) then tmp = 0.99999 end    # These two things are here because ln(0) DNE, so we don't want to do ln(1 - 1.0) or ln(0.0)
    if (tmp <= 0.00001) then tmp = 0.00001 end
    return tmp
end

Feature scaling

You mentioned that you're using two features, which I take to be the player's own rating and the rating difference. Is that right?

Also consider applying some feature scaling as a data preprocessing step, for example rescaling each feature to a common range [the original answer showed the formula as an image]. Or you can standardize, transforming each feature in the data to have zero mean and unit variance.
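As a concrete sketch of the standardization option (my own illustration, not from the original answer; it assumes rows shaped like the question's training_data[:x], with a leading bias column):

    # Standardize each feature column (except the bias column x_0) to zero mean
    # and unit variance. `xs` is an array of rows like [1, rating, diff].
    def standardize(xs)
      n_features = xs.first.length
      means = Array.new(n_features, 0.0)
      stds  = Array.new(n_features, 1.0)

      (1...n_features).each do |j|                 # skip the bias column j = 0
        col = xs.map { |row| row[j].to_f }
        means[j] = col.sum / col.length
        variance = col.map { |v| (v - means[j])**2 }.sum / col.length
        stds[j]  = Math.sqrt(variance)
        stds[j]  = 1.0 if stds[j].zero?            # guard against constant columns
      end

      scaled = xs.map do |row|
        row.each_with_index.map { |v, j| j.zero? ? v : (v - means[j]) / stds[j] }
      end
      [scaled, means, stds]                        # keep means/stds to scale new inputs
    end

    # scaled_x, mu, sigma = standardize(training_data[:x])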

Questions:

  • What is the difference between the blue line and the green line in your graph?
  • Have you tried starting with a very small learning rate, e.g. 0.01 or 0.001?
  • What is the behavior if you just use a fixed learning rate? Try 0.001, 0.01, 0.1, 0.5, 1, 10, etc., and post your results here.

Answer 3 (score: 0)

I think that, rather than playing with the learning rate, you need to normalize the initial data set with feature normalization ((X - mu) / sigma) and then do what you were planning to do.

Without feature normalization, gradient descent goes very wrong on large data sets.