我试图创建一个简单的反向传播脚本。就目前而言,我的代码将经历一次迭代(即for j in xrange(1)
),但如果我尝试添加更多迭代,则脚本会中断。
我想我正在改变一些变量的格式,但我不确定为什么或如何解决这个问题。
import numpy as np
from sympy import *
import math
def nonlin(x,deriv=False):
if(deriv==True):
return x*(1-x)
return 1/(1+np.exp(-x))
def frange(start, stop, step):
i = start
while i < stop:
yield i
i += step
A = []
B = []
for i in frange(-2,2,0.1):
t = np.exp(-abs(i))*sin(i*math.pi)
A.append(i)
B.append(t)
X = np.array(A)
y = np.array(B)
# randomly initialize our weights with uniformly distributed [-.5,.5]
weight = np.random.uniform([-.5,.5])
syn0 = weight[0]
syn1 = weight[1]
for j in xrange(2):
# Feed forward through layers 0, 1, and 2
a0 = X
a1 = nonlin(np.dot(a0,syn0))
a2 = nonlin(np.dot(a1,syn1))
# how much did we miss the target value?
a2_error = y - a2
if (j% 100) == 0:
print "Error:" + str(np.mean(np.abs(a2_error)))
# in what direction is the target value?
# were we really sure? if so, don't change too much.
s2_delta = a2_error*nonlin(a2,deriv=True)
# how much did each k1 value contribute to the k2 error (according to the weights)?
a1_error = s2_delta.dot(syn1.T)
# in what direction is the target k1?
# were we really sure? if so, don't change too much.
s1_delta = a1_error * nonlin(a1,deriv=True)
syn1 += a1.T.dot(s2_delta)
syn0 += a0.T.dot(s1_delta)
经过一次迭代后,我打印出一个错误(这很好),但是在进一步运行循环时会出现以下错误:
Traceback (most recent call last):
File "HW4.py", line 40, in <module>
a1 = nonlin(np.dot(a0,syn0))
File "HW4.py", line 9, in nonlin
return 1/(1+np.exp(-x))