与Pybrain的神经网络不会收敛

时间:2014-10-23 17:51:17

标签: python python-2.7 pybrain

我正在尝试使用Python和Pybrain包构建一个简单的神经网络。 因为我开始学习方法和Pybrain包。我试图用一些我可用的真实数据制作一个非常简单的神经网络!

我知道我的数据存在潜在的连接,但是代码根本没有收敛,并且训练后的结果对于我放在那里的任何真实验证数据集基本相同。下面是我的代码和一小部分数据。我有超过5000行数据可用已知g来训练我的网络,但是加入培训的点数并不重要。

from pybrain.tools.shortcuts import buildNetwork as bld
from pybrain.datasets import SupervisedDataSet as spds
from pybrain.supervised.trainers import BackpropTrainer as bpt
import numpy as np

u,g,r,i,z = np.loadtxt("dataset.dat",unpack=True)
data = spds(4,1)
net = bld(4,1000,1)
for i in range(0,len(umag)):
    data.addSample((u[i],r[i],i[i],z[i]),(g[i]))

trainer = bpt(net,data)
trainer.trainUntilConvergence(dataset=data,maxEpochs=300,validationProportion=0.5)
p = net.activate([17.136,15.812,15.693,15.675])
print p
#expected result 16.225
p = net.activate([19.382,17.684,17.511,17.435])
#   18.195  - expected result
print p

18.14981    15.10829    13.96468    -10.8685    13.20411
16.84580    15.17839    14.61974    14.44930     14.44493
16.70895    15.57959    15.28097    15.16538     15.19260
18.44166    16.32709    15.45345    15.14938     15.04544
18.03881    16.49129    15.96768    15.78446     15.77211
21.15679    18.66248    17.46381    16.97513     16.75475
19.25665    17.80023    17.18956    16.97563     16.94967
17.01522    16.08040    15.85172    15.81930     15.92262
19.21695    17.72263    17.17900    16.98280     16.97201
19.98507    18.56911    17.98143    17.80738     17.81714
16.94824    15.97417    15.70555    15.59221     15.64357
21.20893    19.40982    18.68114    18.46647     18.43065
18.72652    17.38880    16.93716    16.73246     16.75096
20.57421    19.55045    19.15475    18.99772     19.02503
22.48833    20.07709    18.68276    17.60561     17.09613
22.27604    20.34056    19.66521    19.37319     19.30457
20.58372    19.18035    18.64691    18.43370     18.39288
22.25103    20.74570    20.16532    19.94144     19.78580
22.49646    19.63043    18.39409    17.97594     17.77803
19.22686    17.55373    16.97127    16.76445     16.70418
20.44500    19.34502    18.96556    18.80437     18.78767
22.69331    21.19628    19.89190    19.39628     19.11377
19.51075    18.02397    17.46963    17.31436     17.27759
19.92604    18.49456    17.97421    17.83519     17.80557
19.18904    18.22256    17.84221    17.70319     17.64457
20.23186    18.43468    17.81423    17.60103 17.54677
19.86590    18.32822    17.75089    17.57386 17.53067
20.84188    19.78345    19.42506    19.27895     19.34572
22.14103    21.86670    21.74832    21.61244     21.99680
18.02018    16.69380    16.23947    16.12869     16.09864
19.92574    18.63316    18.15877    17.95703     17.90224

1 个答案:

答案 0 :(得分:1)

一般来说,如果我将数据缩放到介于0和1之间,或者更好地介于0.1和0.9之间,我会得到更好的结果。神经元输出通常介于0和1之间。您可以尝试将输入和输出缩放到此范围内,看看是否可以获得更好的结果。