MATLAB中的卷积神经网络。再添一层

时间:2017-11-29 12:57:05

标签: matlab deep-learning conv-neural-network

我试图做的是证明如果我再向CNN添加一层,准确性会更高。

代码如下所示。

此代码来自https://github.com/lhoang29/DigitRecognition/blob/master/cnnload.m

我正处于CNN的初级阶段,并尝试再扩展一层,包括 卷积和汇集阶段。我尝试了几种方法,但似乎没有用。有人可以告诉我如何再扩展一层吗?

谢谢。以下是代码

主函数代码:

% Train and evaluate a simple CNN on the MNIST handwritten-digit data.
clear all; close all; clc;

maxtrain = 10000; % cap on the number of training samples used
iter = 10;        % number of training epochs
eta = 0.01;       % SGD learning rate

%% Data Load

fidTrainLbl = fopen('train-labels-idx1-ubyte');
fidTrainImg = fopen('train-images-idx3-ubyte');
fidTestLbl = fopen('t10k-labels-idx1-ubyte');
fidTestImg = fopen('t10k-images-idx3-ubyte');

% Training labels: skip the 4-byte magic number, read the count, then the bytes.
fread(fidTrainLbl, 4);
nTrainLbl = toint(fread(fidTrainLbl, 4));
trainlabels = fread(fidTrainLbl, nTrainLbl);

% Training images: magic number, count, height, width, then raw pixels.
% The idx format stores pixels row-major, so swap the first two dims.
fread(fidTrainImg, 4);
nTrainImg = toint(fread(fidTrainImg, 4));
rowsTrain = toint(fread(fidTrainImg, 4));
colsTrain = toint(fread(fidTrainImg, 4));
trainimages = permute(reshape(fread(fidTrainImg,rowsTrain*colsTrain*nTrainImg),rowsTrain,colsTrain,nTrainImg), [2 1 3]);

% Test labels.
fread(fidTestLbl, 4);
nTestLbl = toint(fread(fidTestLbl, 4));
testlabels = fread(fidTestLbl, nTestLbl);

% Test images.
fread(fidTestImg, 4);
nTestImg = toint(fread(fidTestImg, 4));
rowsTest = toint(fread(fidTestImg, 4));
colsTest = toint(fread(fidTestImg, 4));
testimages = permute(reshape(fread(fidTestImg, rowsTest*colsTest*nTestImg),rowsTest,colsTest,nTestImg), [2 1 3]);

%% CNN Training

[missimages, misslabels] = cnntrain(trainlabels,trainimages,testlabels,testimages,maxtrain,iter,eta);

%% CNN Testing

showmiss(missimages,misslabels,testimages,testlabels,25,2);

训练代码:

function [missimages,misslabels] = cnntrain(trainlabels,trainimages,testlabels,testimages,maxtrain,iter,eta)

% Trains a three-layer CNN -- 5x5 convolutions, 2x2 average pooling with a
% learned per-map scale/bias, and a fully-connected tanh output layer --
% using per-sample stochastic gradient descent, then evaluates on the test
% set and returns the indices and predicted labels of misclassified images.
% NOTE(review): the declaration line
%   function [missimages,misslabels] = cnntrain(trainlabels,trainimages,testlabels,testimages,maxtrain,iter,eta)
% immediately precedes this block in the original file.
fn = 5; % number of kernels for layer 1
ks = 5; % size of kernel

[h,w,n] = size(trainimages);
n = min(n,maxtrain); % cap how many training samples are actually used

% normalize data to [-1,1] range
nitrain = (trainimages / 255) * 2 - 1;
nitest = (testimages / 255) * 2 - 1;

% train with backprop
% 'valid' convolution shrinks each spatial dimension by ks-1
h1 = h-ks+1;
w1 = w-ks+1;
A1 = zeros(h1,w1,fn); % layer-1 pre-activations, one map per kernel

% 2x2 average pooling halves each dimension
% (assumes h1 and w1 are even, e.g. 28-5+1 = 24 for MNIST)
h2 = h1/2;
w2 = w1/2;
I2 = zeros(h2,w2,fn); % pooled maps before scale/bias
A2 = zeros(h2,w2,fn); % layer-2 pre-activations

A3 = zeros(10,1); % output pre-activations, one per digit class

% kernels for layer 1
W1 = randn(ks,ks,fn) * .01;
B1 = ones(1,fn);

% scale parameter and bias for layer 2
S2 = randn(1,fn) * .01;
B2 = ones(1,fn);

% weights and bias parameters for fully-connected output layer
W3 = randn(h2,w2,fn,10) * .01;
B3 = ones(10,1);

% true outputs
% one-hot targets scaled to +/-1 to match the tanh output range
Y = eye(10)*2-1;

for it=1:iter
    err = 0;
    for im=1:n
        %------------ FORWARD PROP ------------%
        % Layer 1: convolution with bias followed by sigmoidal squashing
        % (the kernel is index-reversed so convn computes cross-correlation)
        for fm=1:fn
            A1(:,:,fm) = convn(nitrain(:,:,im),W1(end:-1:1,end:-1:1,fm),'valid') + B1(fm);
        end
        Z1 = tanh(A1);

        % Layer 2: average/subsample with scaling and bias
        for fm=1:fn
            I2(:,:,fm) = avgpool(Z1(:,:,fm));
            A2(:,:,fm) = I2(:,:,fm) * S2(fm) + B2(fm);
        end
        Z2 = tanh(A2);

        % Layer 3: fully connected
        % ('valid' convn over the whole volume reduces to a dot product)
        for cl=1:10
            A3(cl) = convn(Z2,W3(end:-1:1,end:-1:1,end:-1:1,cl),'valid') + B3(cl);
        end
        Z3 = tanh(A3); % Final output
        % accumulate squared-error loss; labels are 0-9, hence the +1 index
        err = err + .5 * norm(Z3 - Y(:,trainlabels(im)+1),2)^2;

        %------------ BACK PROP ------------%
        % Compute error at output layer (tanh derivative is 1 - Z3.^2)
        Del3 = (1 - Z3.^2) .* (Z3 - Y(:,trainlabels(im)+1));

        % Compute error at layer 2 by pushing Del3 back through W3
        Del2 = zeros(size(Z2));
        for cl=1:10
            Del2 = Del2 + Del3(cl) * W3(:,:,:,cl);
        end
        Del2 = Del2 .* (1 - Z2.^2);

        % Compute error at layer 1
        % S2/4: pooling scale times the 1/4 weight of each pixel in its 2x2
        % average; floor((i+1)/2) maps layer-1 coords to pooled coords
        Del1 = zeros(size(Z1));
        for fm=1:fn
            Del1(:,:,fm) = (S2(fm)/4)*(1 - Z1(:,:,fm).^2);
            for ih=1:h1
                for iw=1:w1
                    Del1(ih,iw,fm) = Del1(ih,iw,fm) * Del2(floor((ih+1)/2),floor((iw+1)/2),fm);
                end
            end
        end

        % Update bias at layer 3
        DB3 = Del3; % gradient w.r.t bias
        B3 = B3 - eta*DB3;

        % Update weights at layer 3
        for cl=1:10
            DW3 = DB3(cl) * Z2; % gradient w.r.t weights
            W3(:,:,:,cl) = W3(:,:,:,cl) - eta * DW3;
        end

        % Update scale and bias parameters at layer 2
        for fm=1:fn
            DS2 = convn(Del2(:,:,fm),I2(end:-1:1,end:-1:1,fm),'valid'); % scalar dot product
            S2(fm) = S2(fm) - eta * DS2;

            DB2 = sum(sum(Del2(:,:,fm)));
            B2(fm) = B2(fm) - eta * DB2;
        end

        % Update kernel weights and bias parameters at layer 1
        for fm=1:fn
            DW1 = convn(nitrain(:,:,im),Del1(end:-1:1,end:-1:1,fm),'valid');
            W1(:,:,fm) = W1(:,:,fm) - eta * DW1;

            DB1 = sum(sum(Del1(:,:,fm)));
            B1(fm) = B1(fm) - eta * DB1;
        end
    end
    disp(['Error: ' num2str(err) ' at iteration ' num2str(it)]);
end

%------------ TEST PASS ------------%
% Forward-propagate every test image through the trained network.
miss = 0;
numtest=size(testimages,3);
% Preallocated to numtest; only the first `miss` entries are filled
% (showmiss later uses nnz to determine how many are valid).
missimages = zeros(1,numtest);
misslabels = zeros(1,numtest);
for im=1:numtest
    % Layer 1: convolution + bias + tanh (same as training forward prop)
    for fm=1:fn
        A1(:,:,fm) = convn(nitest(:,:,im),W1(end:-1:1,end:-1:1,fm),'valid') + B1(fm);
    end
    Z1 = tanh(A1);

    % Layer 2: average/subsample with scaling and bias
    for fm=1:fn
        I2(:,:,fm) = avgpool(Z1(:,:,fm));
        A2(:,:,fm) = I2(:,:,fm) * S2(fm) + B2(fm);
    end
    Z2 = tanh(A2);

    % Layer 3: fully connected
    for cl=1:10
        A3(cl) = convn(Z2,W3(end:-1:1,end:-1:1,end:-1:1,cl),'valid') + B3(cl);
    end
    Z3 = tanh(A3); % Final output

    % Predicted class = argmax of outputs; pl is 1-based, labels are 0-based
    [pm,pl] = max(Z3);
    if pl ~= testlabels(im)+1
        miss = miss + 1;
        missimages(miss) = im;
        misslabels(miss) = pl - 1;
    end
end
disp(['Miss: ' num2str(miss) ' out of ' num2str(numtest)]);

end

function [pr] = avgpool(img)
% AVGPOOL  2x2 non-overlapping average pooling.
% Each output pixel is the mean of one 2x2 block of the input image
% (both dimensions of img are assumed to be even).
    tl = img(1:2:end, 1:2:end); % top-left pixel of each 2x2 block
    bl = img(2:2:end, 1:2:end); % bottom-left
    tr = img(1:2:end, 2:2:end); % top-right
    br = img(2:2:end, 2:2:end); % bottom-right
    pr = (tl + bl + tr + br) / 4;
end

显示准确性的代码

function [] = showmiss(missim,misslab,testimages,testlabels,numshow,numpages)
% SHOWMISS  Display grids of misclassified test digits.
% Each figure shows up to numshow images in a square subplot grid; at most
% numpages figures are opened. Titles read "<true label>:<predicted label>".
% missim/misslab are padded with zeros, so nnz gives the valid count.
    nummiss = nnz(missim);

    page = 1;
    gridside = floor(sqrt(numshow));
    for first=1:numshow:nummiss
        figure(floor(first/numshow) + 1);
        last = min(nummiss, first+numshow-1);
        for k=first:last
            subplot(gridside,gridside,k-first+1);
            imshow(testimages(:,:,missim(k)));
            title(strcat(num2str(testlabels(missim(k))), ':', num2str(misslab(k))));
        end
        page = page + 1;
        if page > numpages
            break;
        end
    end

end

toint 函数:

function [x] = toint(b)
% TOINT  Interpret a 4-element byte vector as a big-endian 32-bit integer.
% Used to decode the header fields of the MNIST idx files.
    x = 0;
    for k = 1:4
        x = x * 256 + b(k); % Horner's scheme: shift in one byte at a time
    end
end

0 个答案:

没有答案