PyBrain是一个python库,提供(以及其他)易于使用的人工神经网络。
我无法使用pickle或cPickle正确地序列化/反序列化PyBrain网络。
请参阅以下示例:
from pybrain.datasets import SupervisedDataSet
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
import cPickle as pickle
import numpy as np
#generate some data
np.random.seed(93939393)
data = SupervisedDataSet(2, 1)
for x in xrange(10):
y = x * 3
z = x + y + 0.2 * np.random.randn()
data.addSample((x, y), (z,))
#build a network and train it
net1 = buildNetwork( data.indim, 2, data.outdim )
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True)
for i in xrange(4):
trainer1.trainEpochs(1)
print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0])
这是上述代码的输出:
Total error: 201.501998476
value after 0 epochs: 2.79
Total error: 152.487616382
value after 1 epochs: 5.44
Total error: 120.48092561
value after 2 epochs: 7.56
Total error: 97.9884043452
value after 3 epochs: 8.41
如您所见,随着培训的进行,网络总误差会减少。您还可以看到预测值接近预期值12。
现在我们将进行类似的练习,但将包括序列化/反序列化:
print 'creating net2'
net2 = buildNetwork(data.indim, 2, data.outdim)
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
trainer2.trainEpochs(1)
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0])
#So far, so good. Let's test pickle
pickle.dump(net2, open('testNetwork.dump', 'w'))
net2 = pickle.load(open('testNetwork.dump'))
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
print 'loaded net2 using pickle, continue training'
for i in xrange(1, 4):
trainer2.trainEpochs(1)
print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0])
这是该块的输出:
creating net2
Total error: 176.339378639
value after 1 epochs: 5.45
loaded net2 using pickle, continue training
Total error: 123.392181859
value after 1 epochs: 5.45
Total error: 94.2867637623
value after 2 epochs: 5.45
Total error: 78.076711114
value after 3 epochs: 5.45
如您所见,似乎训练对网络有一定影响(报告的总误差值继续减小),但网络的输出值冻结了与第一次训练迭代相关的值。
我需要注意哪些缓存机制会导致这种错误行为?是否有更好的方法来序列化/反序列化pybrain网络?
相关版本号:
P.S。我在项目的网站上创建了a bug report,并将保持SO和bug跟踪器更新
答案 0 :(得分:11)
<强>原因强>
导致此行为的机制是在PyBrain模块中处理参数(.params
)和派生(.derivs
):事实上,所有网络参数都存储在一个数组中,但个人{ {1}}或Module
个对象可以访问“他们自己的”Connection
,但这只是整个数组切片上的一个视图。这允许在同一数据结构上进行本地和网络范围的写入和读出。
显然,这个切片视图链接会因为腌渍而破坏。
<强>解决方案强>
插入
.params
从文件加载后(重新创建此共享),它应该可以工作。