在Python中实现感知器算法时遇到问题

时间:2018-08-13 23:31:11

标签: python-3.x machine-learning neural-network perceptron

由于某些原因,我在调试以下代码时遇到问题。感知器在经过几步以权重值作为随机值的步骤后,停止更新自身。我尝试不使用班级来进行工作,并且将所有内容都编辑到最低限度,但仍然存在相同的问题。我还检查了Perceptron.train(),它工作正常。所以,我猜主要的问题是火车功能本身。我是python编程的新手,所以对您有所帮助。     随机导入     将图导入为plt     将numpy导入为np

#-----Function Of the line that seperates the two different Data Types-----$
def f(x):
    return x

#-----Activation Function-----#
def act(x):
    if x >= 0:
        return 1.0
    return 0.0

class Point:

    def __init__(self, x, y):
        self.X = x
        self.Y = y
        if y > f(x):
            self.Target = 1.0
        else:
            self.Target = 0.0


class Perceptron:

    def __init__(self, n, actFunc = act, lr = 0.2):
        self.Weights = [0 for i in range(n)]
        self.ActFunc = actFunc
        self.LR = lr

    def guess(self, inputs):
        valSum = 0
        for i in range(len(inputs)):
            valSum += self.Weights[i] * inputs[i]
        return self.ActFunc(valSum)

    def train(self, inputs, target):
        cal = self.guess(inputs)
        err = target - cal
        for i in range(0, len(self.Weights)):
            self.Weights[i] += self.LR * err * inputs[i]

    def printWeights(self):
        for i in range(len(self.Weights)):
            print("WEIGHT[" + str(i) + "] = " + str(self.Weights[i]))
        print("")

    def lineFunc(self):
        # y = w0 + w1x + w2y
        # (1 - w2)y = w0 + w1x
        # y = w0/(1-w2) + w1/(1 - w2)x
        w0 = self.Weights[0]
        w1 = self.Weights[1]
        w2 = self.Weights[2]
        return (str(w0/(1 - w2)) + " + " + str(w1/(1 - w2)) + " * x")

#-----INITIALISING DATA------#
brain = Perceptron(3)

n = 20
points = [Point(random.uniform(-10, 10), random.uniform(-10, 10)) for x in range(n)]

t = 1000

#-----Training-----#
for i in range(t):
    point = points[random.randrange(0, n)]
    brain.train([1, point.X, point.Y], point.Target)
    brain.printWeights()
    print(brain.lineFunc())

1 个答案:

答案 0 :(得分:0)

我自己发现了问题。 LineFunc()方法中存在错误。返回值错误,应该是:

    return (str(-w0/w2) + " + " + str(-w1/w2) + " * x")