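How to save and restore a PyBrain neural network?

Posted: 2014-09-25 21:18:39

Tags: python save pybrain

I created a simple PyBrain neural network. I want to save what it has trained and learned so that the network can continue learning later.

Here is my code. I can't figure out how to get this to work.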

from pybrain.datasets import SupervisedDataSet
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.tools.xml.networkwriter import NetworkWriter
from pybrain.tools.xml.networkreader import NetworkReader

# Load the previously trained network
# (NetworkWriter writes XML, even though this file is named .csv)
n = NetworkReader.readFrom('weatherlearned.csv')

# Each sample has 6 inputs and 1 output
ds = SupervisedDataSet(6, 1)
with open('weather.csv', 'r') as tf:
    for line in tf:
        try:
            data = [float(x) for x in line.strip().split(',') if x != '']
            indata = tuple(data[:6])
            outdata = tuple(data[6:])
            ds.addSample(indata, outdata)
        except ValueError as e:
            print("error %s on line %r" % (e, line))

# Builds a new network with random weights and assigns it to n
n = buildNetwork(ds.indim, 8, 8, ds.outdim, recurrent=True)
t = BackpropTrainer(n, learningrate=0.005, momentum=0.05, verbose=True)
t.trainOnDataset(ds, 1000)
t.testOnData(verbose=True)
NetworkWriter.writeToFile(n, 'weatherlearned.csv')

I also tried pickle:

import pickle

from pybrain.datasets import SupervisedDataSet
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer

# Each sample has 6 inputs and 1 output
ds = SupervisedDataSet(6, 1)
with open('weather.csv', 'r') as tf:
    for line in tf:
        try:
            data = [float(x) for x in line.strip().split(',') if x != '']
            indata = tuple(data[:6])
            outdata = tuple(data[6:])
            ds.addSample(indata, outdata)
        except ValueError as e:
            print("error %s on line %r" % (e, line))

n = buildNetwork(ds.indim, 8, 8, ds.outdim, recurrent=True)
t = BackpropTrainer(n, learningrate=0.05, momentum=0.05, verbose=True)

# The network is pickled here, before training has run
with open('weatherlearned.csv', 'wb') as fileObject:
    pickle.dump(n, fileObject)

t.trainOnDataset(ds, 1000)
t.testOnData(verbose=True)

# Reads back the network that was pickled above
with open('weatherlearned.csv', 'rb') as fileObject:
    n = pickle.load(fileObject)
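In the pickle attempt above, pickle.dump runs before trainOnDataset, so the file can only ever hold the untrained, randomly initialized weights. A minimal sketch of the intended order (train, then pickle in binary mode, then unpickle in a later run) follows. To stay self-contained it uses hypothetical helpers (build_toy_network, train_step) and a temporary .pkl path instead of PyBrain; with PyBrain, the network returned by buildNetwork would take the place of the dict, assuming that network object pickles cleanly.

```python
import os
import pickle
import random
import tempfile


def build_toy_network(n_params, seed=None):
    # Hypothetical stand-in for buildNetwork: a dict of random weights.
    rng = random.Random(seed)
    return {"params": [rng.uniform(-1.0, 1.0) for _ in range(n_params)]}


def train_step(net, lr=0.1):
    # Placeholder "training": move every weight a little towards zero.
    net["params"] = [w - lr * w for w in net["params"]]


def save_network(net, path):
    # Binary mode; the with-block closes (and flushes) the file before any reload.
    with open(path, "wb") as f:
        pickle.dump(net, f)


def load_network(path):
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical file location, used only for this sketch.
path = os.path.join(tempfile.mkdtemp(), "weatherlearned.pkl")

net = build_toy_network(4, seed=0)
for _ in range(10):
    train_step(net)            # 1. train first ...
save_network(net, path)        # 2. ... then save the trained weights

restored = load_network(path)  # 3. later run: load instead of rebuilding
print(restored == net)         # True: the trained weights came back
```

The key design point is ordering: whatever state the object is in when pickle.dump runs is exactly what pickle.load returns later.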

1 answer:

Answer 0 (score: 0)

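The links suggested in the comments load a network from a file the same way my answer does. However, if the buildNetwork call is still in your code after the network has been loaded from the file, that is probably the problem: the loaded network has to be used as-is, without calling buildNetwork again.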

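A minimal sketch of that idea: build a new network only when no saved one exists, and otherwise use the loaded network unchanged. The names load_or_build, build_fn and save, and the .pkl path, are hypothetical, and pickle stands in for the file format; in the question's XML-based script the two branches would correspond to NetworkReader.readFrom(...) and buildNetwork(ds.indim, 8, 8, ds.outdim, recurrent=True).

```python
import os
import pickle
import random
import tempfile


def build_fn():
    # Hypothetical stand-in for buildNetwork(...): fresh random weights.
    return {"params": [random.uniform(-1.0, 1.0) for _ in range(4)]}


def load_or_build(path, build):
    """Return the saved network if one exists, else a freshly built one."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)  # keep the trained weights; no rebuild
    return build()


def save(net, path):
    with open(path, "wb") as f:
        pickle.dump(net, f)


# Hypothetical file location, used only for this sketch.
path = os.path.join(tempfile.mkdtemp(), "weatherlearned.pkl")

# First run: nothing saved yet, so a new network is built (then trained and saved).
first = load_or_build(path, build_fn)
save(first, path)

# Later runs: the saved network is loaded and build_fn is never called.
second = load_or_build(path, build_fn)
print(second == first)  # True: same weights, so training can continue
```

Keeping the build call inside the "nothing saved yet" branch is what prevents a trained network from being silently replaced by a random one.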

The PyBrain documentation says this about buildNetwork:

"The net is already initialized with random values."

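So if buildNetwork is called after the previously trained network has been loaded, it creates a new network initialized with random weights, and the variable that held the loaded network now refers to this new one, so the trained weights are lost. Loading a network from a file does not require buildNetwork at all.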
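The effect can be reproduced with any builder that draws random initial weights. In this sketch, build_network is a hypothetical stand-in for buildNetwork and plain lists stand in for the network; calling the builder after the "load" replaces the trained weights, which is what happens in the question's first script when buildNetwork runs after NetworkReader.readFrom.

```python
import random


def build_network(n_params, rng):
    # Hypothetical stand-in for buildNetwork: every call draws new random weights.
    return [rng.uniform(-1.0, 1.0) for _ in range(n_params)]


rng = random.Random(42)

trained = build_network(4, rng)
trained = [0.5 * w for w in trained]  # pretend these came from training

net = list(trained)                   # "load" the saved weights
net = build_network(4, rng)           # calling the builder again ...
print(net == trained)                 # False: the trained weights are gone
```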