扩展简单的NN示例以在mnist上工作

时间:2019-06-01 15:28:33

标签: python numpy neural-network deep-learning mnist

我试图了解如何使用numpy和mnist构建简单的NN。为此,我通过添加图层并更改其中的神经元数量,从Internet扩展了一个简单示例。

问题是,网络似乎没有学习。最后,我绘制了每个时期的错误分类图,在第一个时期下降之后,然后保持不变。

由于我是一个完整的初学者,所以我不确定设计的网络是否还可以识别数字。同样,我对错误分类进行刻画的方式似乎也存在缺陷,因为正确的预测只是输出分布中最有可能出现的数字。

import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf


def nonlin(x, deriv=False):
    if deriv:
        return x * (1-x)
    return 1 / (1 + np.exp(-x))


# converts normal labels 0,1,2, ... ,9 into the output representation
def binary_y(datapoints):
    out = []
    for data in datapoints:
        l = [0 for i in range(data)] + [1]
        r = [0 for i in range(9-data)]
        b = l+r
        out.append(np.array(b))
    return np.array(out)


# turns output distribution of 10 into most likeley number
def get_pred(distributions):
    out = []
    for dist in distributions:
        out.append(np.argmax(dist))
    return np.array(out).reshape((len(distributions), 1))


# flattens pictures in a given set
def flatten(dataset):
    out = []
    for data in dataset:
        out.append(data.reshape((len(data[0])**2)))
    return np.array(out)


mnist = tf.keras.datasets.mnist
(xTrain, yTrain), (xTest, yTest) = mnist.load_data()
yTrainEncoded = binary_y(yTrain)
xTrain = xTrain / np.max(xTrain)
xTrain = flatten(xTrain)


# weights
np.random.seed(1)
syn0 = 2 * np.random.random((784, 128)) - 1
syn1 = 2 * np.random.random((128, 128)) - 1
syn2 = 2 * np.random.random((128, 10)) - 1

errors = []

for epoch in range(10):
    l0 = xTrain
    l1 = nonlin(np.dot(l0, syn0))
    l2 = nonlin(np.dot(l1, syn1))
    l3 = nonlin(np.dot(l2, syn2))

    l3_error = yTrainEncoded - l3

    deviations = [1 for x in (yTrain.reshape(
        60000, 1) - get_pred(l3)) if x != 0]
    errors.append(sum(deviations))

    l3_delta = l3_error * nonlin(l3, deriv=True)
    l2_error = l3_delta.dot(syn2.T)
    l2_delta = l2_error * nonlin(l2, deriv=True)
    l1_error = l2_delta.dot(syn1.T)
    l1_delta = l1_error * nonlin(l1, deriv=True)

    syn2 += l2.T.dot(l3_delta)
    syn1 += l1.T.dot(l2_delta)
    syn0 += l0.T.dot(l1_delta)

plt.plot(errors)
plt.show()

我非常感谢您提供有关如何解决此类问题的建议:)

0 个答案:

没有答案