Deep learning - Theano cost function keeps increasing on my training set

Asked: 2016-05-30 11:02:02

Tags: python machine-learning computer-vision neural-network

I am trying to train a very simple neural network with 2 hidden layers (as an example), each containing 5 neurons.

For some reason, my cost function (chosen as cross-entropy, but that is not important here) keeps increasing on a certain dataset.

Here is my code -

import theano_net as N
from load import mnist
import theano
from theano import tensor as T
import numpy as np
import cv2

def floatX(X):
    return np.asarray(X,dtype=theano.config.floatX)

def init_weights(shape):
    return theano.shared(floatX(np.random.randn(*shape)*0.01))

def appendGD(params, grad, step):
    # plain gradient descent: each parameter p is updated to p - step * g
    updates = []
    for p, g in zip(params, grad):
        updates.append([p, p - (g * step)])
    return updates

def model(X,w1,w2,wo):
    h0 = X
    z1 = T.dot(h0, w1.T)  ## n on m1
    h1 = T.nnet.sigmoid(z1)

    z2 = T.dot(h1, w2.T)  ## n on m2
    h2 = T.nnet.sigmoid(z2)

    zo = T.dot(h2, wo.T)

    return T.nnet.softmax(zo)

numOfTrainPics = 4872
numOfTestPics = 382
numOfPixels = 40000
numOfLabels = 6

trX = np.zeros((numOfTrainPics,numOfPixels))
trY = np.zeros((numOfTrainPics,numOfLabels))
teX = np.zeros((numOfTestPics,numOfPixels))
teY = np.zeros((numOfTestPics,numOfLabels))

for i in range(1, numOfTrainPics + 1): #generate trX and trY
    img = cv2.imread('C:\\Users\\Oria\\Desktop\\Semester B\\Computer Vision Cornel 2016\\Train\\Train\\%s.jpg' %(i))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)  # note: cv2.imread returns BGR, so COLOR_BGR2GRAY would be the matching flag
    img = np.reshape(img, (1, numOfPixels))
    trX[i-1, :] = img
    # one-hot labels, assigned by image index range
    if i < 1330:
        trY[i-1, 0] = 1
    elif i < 1817:
        trY[i-1, 1] = 1
    elif i < 2389:
        trY[i-1, 2] = 1
    elif i < 3043:
        trY[i-1, 3] = 1
    elif i < 4438:
        trY[i-1, 4] = 1
    else:
        trY[i-1, 5] = 1
for i in range(1, numOfTestPics + 1): #generate teX and teY
    img = cv2.imread('C:\\Users\\Oria\\Desktop\\Semester B\\Computer Vision Cornel 2016\\Test\\Test\\%s.jpg' %(i))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img = np.reshape(img, (1, numOfPixels))
    teX[i-1, :] = img
    if i < 59:
        teY[i-1, 0] = 1
    elif i < 120:
        teY[i-1, 1] = 1
    elif i < 185:
        teY[i-1, 2] = 1
    elif i < 261:
        teY[i-1, 3] = 1
    elif i < 326:
        teY[i-1, 4] = 1
    else:
        teY[i-1, 5] = 1
print "matrices generated"
###

x = T.fmatrix()
y = T.fmatrix()
step = 0.1
m1 = 5
m2 = 5
w1 = init_weights((m1, numOfPixels))
w2 = init_weights((m2, m1))
wo = init_weights((numOfLabels, m2))

temp = model(x, w1, w2, wo)  # class probabilities, shape (n, numOfLabels)

predictions = T.argmax(temp, axis=1)

cost = T.mean(T.nnet.categorical_crossentropy(temp, y))

params = [w1, w2, wo]
gradient = T.grad(cost=cost, wrt=params)
update = appendGD(params, gradient, step)

train = theano.function(inputs=[x, y], outputs=cost, updates=update, allow_input_downcast=True)
predict = theano.function(inputs=[x], outputs=[predictions], allow_input_downcast=True)

for i in range(10000):
    # mini-batches of 241 examples; the trailing partial batch is dropped by zip
    for start, end in zip(range(0, len(trX), 241), range(241, len(trX), 241)):
        cost = train(trX[start:end], trY[start:end])
    print cost

For the trX loaded by this code, my cost function keeps increasing.

However, when I run the same code with trX and trY taken from the MNIST dataset instead, it works properly and the cost function decreases.

I don't understand why this happens or how to fix it.

A possible clue: when I look at the first row of trX (the first picture) for the MNIST dataset, it is a very sparse matrix whose non-zero elements all lie between 0 and 1.

[ 0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.01176471
  0.07058824  0.07058824  0.07058824  0.49411765  0.53333333  0.68627451
  0.10196078  0.65098039  1.          0.96862745  0.49803922  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.11764706  0.14117647  0.36862745
  0.60392157  0.66666667  0.99215686  0.99215686  0.99215686  0.99215686
  0.99215686  0.88235294  0.6745098   0.99215686  0.94901961  0.76470588
  0.25098039  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.19215686
  0.93333333  0.99215686  0.99215686  0.99215686  0.99215686  0.99215686
  0.99215686  0.99215686  0.99215686  0.98431373  0.36470588  0.32156863
  0.32156863  0.21960784  0.15294118  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.07058824  0.85882353  0.99215686  0.99215686  0.99215686
  0.99215686  0.99215686  0.77647059  0.71372549  0.96862745  0.94509804
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.31372549  0.61176471
  0.41960784  0.99215686  0.99215686  0.80392157  0.04313725  0.
  0.16862745  0.60392157  0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.05490196  0.00392157  0.60392157  0.99215686  0.35294118  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.54509804  0.99215686  0.74509804  0.00784314
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.04313725  0.74509804  0.99215686
  0.2745098   0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.1372549
  0.94509804  0.88235294  0.62745098  0.42352941  0.00392157  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.31764706  0.94117647  0.99215686  0.99215686  0.46666667  0.09803922
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.17647059  0.72941176  0.99215686  0.99215686
  0.58823529  0.10588235  0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.0627451   0.36470588
  0.98823529  0.99215686  0.73333333  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.97647059  0.99215686  0.97647059  0.25098039  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.18039216  0.50980392
  0.71764706  0.99215686  0.99215686  0.81176471  0.00784314  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.15294118  0.58039216  0.89803922
  0.99215686  0.99215686  0.99215686  0.98039216  0.71372549  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.09411765  0.44705882  0.86666667  0.99215686
  0.99215686  0.99215686  0.99215686  0.78823529  0.30588235  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.09019608  0.25882353  0.83529412  0.99215686  0.99215686
  0.99215686  0.99215686  0.77647059  0.31764706  0.00784314  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.07058824  0.67058824  0.85882353  0.99215686  0.99215686  0.99215686
  0.99215686  0.76470588  0.31372549  0.03529412  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.21568627  0.6745098   0.88627451  0.99215686  0.99215686  0.99215686
  0.99215686  0.95686275  0.52156863  0.04313725  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.53333333  0.99215686  0.99215686  0.99215686  0.83137255
  0.52941176  0.51764706  0.0627451   0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.        ]

When I look at the trX loaded by my own code, however, trX[0] is mostly non-zero, with elements between 0 and 255.
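
A quick way to confirm the range mismatch between the two datasets is to print the extremes of the first row (a minimal check, using the trX array defined above):

print trX[0].min(), trX[0].max()  # up to 255 for my data, at most 1.0 for MNIST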

All I want is to train the network on my own dataset. It shouldn't be too difficult, and the code is proven to work with MNIST. I just don't understand how to load my dataset correctly.

1 Answer:

Answer 0 (score: 1)

You have to normalize your input data to the [0,1] or [-1,1] range, as you have already seen in the MNIST dataset. Without normalization, training a neural network is much harder: with raw pixel values up to 255, the first layer's pre-activations are huge, the sigmoids saturate, and gradient descent at a fixed step size can easily drive the cost up instead of down.
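
To see the saturation concretely, here is a toy illustration (not the question's exact numbers; the input is a hypothetical mid-grey image and the weights use the question's 0.01 init scale):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

np.random.seed(0)
x = np.full(40000, 128.0)           # one unscaled mid-grey image
w = np.random.randn(40000) * 0.01   # the question's weight init scale
z = np.dot(x, w)                    # pre-activation of a single hidden unit
s = sigmoid(z)
print z                             # typically hundreds in magnitude
print s * (1.0 - s)                 # sigmoid derivative: effectively 0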

You can do this by subtracting the dataset's mean and dividing by its standard deviation, or by simply doing min-max normalization, which yields the [0,1] range.
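
For instance, both options could look like this (a minimal sketch on the trX/teX arrays from the question; the statistics are taken from the training set only and reused on the test set, so both sets stay on the same scale):

# standardization: zero mean, unit variance
mean = trX.mean()
std = trX.std()
trX_std = (trX - mean) / std
teX_std = (teX - mean) / std

# min-max normalization into [0, 1]
lo, hi = trX.min(), trX.max()
trX_mm = (trX - lo) / (hi - lo)
teX_mm = (teX - lo) / (hi - lo)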

For 8-bit-per-channel images you can divide each pixel by 255 to get the [0,1] range, which is usually, though not always, enough for NN training.
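
Applied to the loading loop from the question, the scaling can happen right where each image is stored (a sketch; only the assignment line changes):

trX[i-1, :] = img / 255.0        # 8-bit pixels scaled into [0, 1]
# or, for the [-1, 1] range:
# trX[i-1, :] = img / 127.5 - 1.0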