我的MNIST数据集的python代码显示了巨大的错误值,我缺少什么?

时间:2019-05-01 22:49:55

标签: python machine-learning mnist

在下面,您可以找到我的代码,这是我第一个用于自学机器学习和Python的认真代码。我试图从头开始编写代码,而不使用像NumPy这样的库。对于单个输入和输出,该代码有效,但是当涉及到实际数据集(在本例中为10个输出的784个输入)时,它将返回无穷大作为错误。我检查了我认为可能没有成功的所有问题。

该代码可能是一个肮脏的解决方案。我从研究Trask Github的代码开始,他的代码用于多种输入/输出作品,但是当我修改它以使用MNIST时,一切都变得疯狂。 有人可以看一下并帮助我知道我所缺少的是什么吗?赞赏。

for i in range (x_train.shape[0]):
    x_labels[i,x_label[i]]=1
def w_sum(a,b):
    assert(len(a) == len(b))
    output = 0
    for i in range(len(a)):
        output += (a[i] * b[i])
    return output

def neural_network(input1, weights):
    pred = vect_mat_mul(input1,weights)
    return pred

def vect_mat_mul(vect,matrix):
    output = np.zeros(10)
    for i in range(10):

        output[i] = w_sum(vect[0],matrix[:,i])

    return output
def outer_prod(a, b):
    out = np.zeros((len(a), len(b)))
    for i in range(len(a)):
        for j in range(len(b)):
            out[i][j] = a[i] * b[j]
    return out

(x_train,x_label),(t_test,t_label)=ks.datasets.mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],28*28)
weights=np.random.rand(784,10)
x_labels=np.zeros((x_train.shape[0],10))
alpha = 0.00001

for i in range(x_train.shape[0]):
    error = np.zeros(10)
    delta = np.zeros(10)
    for iter in range(50):
        x_train_to_NN = np.array([x_train[i]])
        pred = neural_network(x_train_to_NN, weights)
        for j in range(10):
            error[j] = (pred[j] - x_labels[i, j]) ** 2
            delta[j] = pred[j] - x_labels[i, j]
        weight_deltas = outer_prod(x_train[i], delta) #calculate the gradient
        for idx in range(784):
            for jdx in range(10):
                weights[idx][jdx] -= alpha * weight_deltas[idx][jdx] #update weight matrix

print('key=', i, '\n Error=', error, '\n Delta=', delta, '\n Prediction=', pred)

2 个答案:

答案 0 :(得分:0)

我终于找到了答案,那就是“ Gradient Clipping”。 问题是在计算梯度时,需要对其进行限制(标准化),以避免梯度爆炸。

答案 1 :(得分:-1)

我在这里看到很多错误。 使用tensorflow,pyTorch等NN库的主要好处之一是,它们可以为您处理漂亮而优美的线性代数部分。 例如,神经网络的所有权重都以特殊方式初始化,因此它们既不大于1,也不小于1,否则,梯度会消失或爆炸得太快。 此外,不清楚您在哪里计算梯度,更新成本函数等。要计算梯度,您需要转到日志空间并返回,以便再次避免会导致梯度爆炸的浮点错误(因此无限误差)。 :) 我建议您对理论部分有一个更好的了解,然后尝试分别实现每个部分。 干杯,