我试图做的是证明如果我再向CNN添加一层,准确性会更高。
代码如下所示。
此代码来自https://github.com/lhoang29/DigitRecognition/blob/master/cnnload.m
我刚开始学习CNN,想再扩展一层(包括卷积和池化两个阶段)。我尝试了几种方法,但似乎都不起作用。有人可以告诉我如何再扩展一层吗?
谢谢。以下是代码。
主要功能代码:
% Entry-point script: loads the four IDX data files from the current folder,
% trains the CNN with cnntrain, and displays misclassified test digits with
% showmiss.
clear all; close all; clc;
maxtrain = 10000;   % upper bound on the number of training images used
iter = 10;          % number of passes over the training images
eta = 0.01;         % learning rate passed to cnntrain
%% Data Load
% Each file is opened, checked, read, and closed in turn so that no file
% handle stays open after the data has been loaded. fread with its default
% precision returns each byte as a double, which is what toint expects.

% read train labels
trlblid = fopen('train-labels-idx1-ubyte');
if trlblid < 0
    error('main:fileOpen', 'Cannot open file: %s', 'train-labels-idx1-ubyte');
end
fread(trlblid, 4);                          % discard the first 4 header bytes
numtrlbls = toint(fread(trlblid, 4));       % number of labels (4-byte field)
trainlabels = fread(trlblid, numtrlbls);
fclose(trlblid);

% read train data
trimgid = fopen('train-images-idx3-ubyte');
if trimgid < 0
    error('main:fileOpen', 'Cannot open file: %s', 'train-images-idx3-ubyte');
end
fread(trimgid, 4);                          % discard the first 4 header bytes
numtrimg = toint(fread(trimgid, 4));
trimgh = toint(fread(trimgid, 4));
trimgw = toint(fread(trimgid, 4));
% Reshape the flat byte stream into one slice per image, then swap the first
% two dimensions of every slice.
trainimages = permute(reshape(fread(trimgid,trimgh*trimgw*numtrimg),trimgh,trimgw,numtrimg), [2 1 3]);
fclose(trimgid);

% read test labels
tslblid = fopen('t10k-labels-idx1-ubyte');
if tslblid < 0
    error('main:fileOpen', 'Cannot open file: %s', 't10k-labels-idx1-ubyte');
end
fread(tslblid, 4);                          % discard the first 4 header bytes
numtslbls = toint(fread(tslblid, 4));
testlabels = fread(tslblid, numtslbls);
fclose(tslblid);

% read test data
tsimgid = fopen('t10k-images-idx3-ubyte');
if tsimgid < 0
    error('main:fileOpen', 'Cannot open file: %s', 't10k-images-idx3-ubyte');
end
fread(tsimgid, 4);                          % discard the first 4 header bytes
numtsimg = toint(fread(tsimgid, 4));
tsimgh = toint(fread(tsimgid, 4));
tsimgw = toint(fread(tsimgid, 4));
testimages = permute(reshape(fread(tsimgid, tsimgh*tsimgw*numtsimg),tsimgh,tsimgw,numtsimg), [2 1 3]);
fclose(tsimgid);

%% CNN Training
[missimages, misslabels] = cnntrain(trainlabels,trainimages,testlabels,testimages,maxtrain,iter,eta);
%% CNN Testing
% Show up to 2 pages of 25 misclassified test images each.
showmiss(missimages,misslabels,testimages,testlabels,25,2);
% Training code:
function [missimages,misslabels] = cnntrain(trainlabels,trainimages,testlabels,testimages,maxtrain,iter,eta)
% CNNTRAIN Train a three-layer CNN with per-sample backprop and test it.
%
% Network: 5 convolution kernels of size 5x5 with tanh, 2x2 average pooling
% with a trainable scale and bias per feature map followed by tanh, and a
% fully connected 10-unit tanh output layer.
%
% Inputs:
%   trainlabels - labels of the training images; label+1 is used as the
%                 class index, so values 0..9 are expected
%   trainimages - h-by-w-by-n array of training images
%                 (assumed to hold pixel values in 0..255 - TODO confirm)
%   testlabels  - labels of the test images, same convention as trainlabels
%   testimages  - array of test images, one image per slice of dimension 3
%   maxtrain    - at most this many training images are used
%   iter        - number of passes over the training images
%   eta         - learning rate
%
% Outputs:
%   missimages  - 1-by-numtest vector; its leading entries are the indices
%                 of the misclassified test images, the rest are zero
%   misslabels  - 1-by-numtest vector; the predicted label (0-based) for
%                 each entry of missimages
%
% The summed squared error of every pass and the final miss count are
% printed with disp.

fn = 5;   % number of kernels for layer 1
ks = 5;   % size of kernel
ncl = 10; % number of output classes
[h,w,n] = size(trainimages);
n = min(n,maxtrain);

% normalize data to [-1,1] range
nitrain = (trainimages / 255) * 2 - 1;
nitest = (testimages / 255) * 2 - 1;

% size of the layer-1 feature maps ('valid' convolution)
h1 = h-ks+1;
w1 = w-ks+1;
% The 2x2 pooling in layer 2 needs even feature-map sizes; fail early with a
% clear message instead of a non-integer size error further down.
if h1 < 2 || w1 < 2 || mod(h1,2) ~= 0 || mod(w1,2) ~= 0
    error('cnntrain:badImageSize', ...
        'Image size %d-by-%d minus kernel size %d plus 1 must be even and at least 2.', h, w, ks);
end
% size of the pooled layer-2 feature maps
h2 = h1/2;
w2 = w1/2;

% kernels for layer 1
W1 = randn(ks,ks,fn) * .01;
B1 = ones(1,fn);
% scale parameter and bias for layer 2
S2 = randn(1,fn) * .01;
B2 = ones(1,fn);
% weights and bias parameters for fully-connected output layer
W3 = randn(h2,w2,fn,ncl) * .01;
B3 = ones(ncl,1);
% true outputs: column k is the target for class k (+1 on the diagonal, -1 elsewhere)
Y = eye(ncl)*2-1;

% train with backprop, updating the parameters after every single image
for it=1:iter
    err = 0;
    for im=1:n
        %------------ FORWARD PROP ------------%
        [Z1,I2,Z2,Z3] = cnntrain_forward(nitrain(:,:,im),W1,B1,S2,B2,W3,B3);
        target = Y(:,trainlabels(im)+1);
        err = err + .5 * norm(Z3 - target,2)^2;

        %------------ BACK PROP ------------%
        % Compute error at output layer (1 - Z.^2 is the derivative of tanh)
        Del3 = (1 - Z3.^2) .* (Z3 - target);
        % Compute error at layer 2
        Del2 = zeros(size(Z2));
        for cl=1:ncl
            Del2 = Del2 + Del3(cl) * W3(:,:,:,cl);
        end
        Del2 = Del2 .* (1 - Z2.^2);
        % Compute error at layer 1: every layer-2 delta is spread over the
        % 2x2 window it was averaged from (kron with ones(2) replicates each
        % entry into a 2x2 block), scaled by S2/4 and the tanh derivative.
        Del1 = zeros(size(Z1));
        for fm=1:fn
            Del1(:,:,fm) = ((S2(fm)/4)*(1 - Z1(:,:,fm).^2)) .* kron(Del2(:,:,fm), ones(2));
        end

        % Update bias at layer 3
        DB3 = Del3; % gradient w.r.t bias
        B3 = B3 - eta*DB3;
        % Update weights at layer 3
        for cl=1:ncl
            DW3 = DB3(cl) * Z2; % gradient w.r.t weights
            W3(:,:,:,cl) = W3(:,:,:,cl) - eta * DW3;
        end
        % Update scale and bias parameters at layer 2
        for fm=1:fn
            % 'valid' convolution of two equally sized maps, one of them
            % flipped, yields the scalar sum of their element-wise product
            DS2 = convn(Del2(:,:,fm),I2(end:-1:1,end:-1:1,fm),'valid');
            S2(fm) = S2(fm) - eta * DS2;
            DB2 = sum(sum(Del2(:,:,fm)));
            B2(fm) = B2(fm) - eta * DB2;
        end
        % Update kernel weights and bias parameters at layer 1
        for fm=1:fn
            DW1 = convn(nitrain(:,:,im),Del1(end:-1:1,end:-1:1,fm),'valid');
            W1(:,:,fm) = W1(:,:,fm) - eta * DW1;
            DB1 = sum(sum(Del1(:,:,fm)));
            B1(fm) = B1(fm) - eta * DB1;
        end
    end
    disp(['Error: ' num2str(err) ' at iteration ' num2str(it)]);
end

% evaluate on every test image and record the misclassified ones
miss = 0;
numtest=size(testimages,3);
missimages = zeros(1,numtest);
misslabels = zeros(1,numtest);
for im=1:numtest
    [~,~,~,Z3] = cnntrain_forward(nitest(:,:,im),W1,B1,S2,B2,W3,B3);
    [~,pl] = max(Z3); % index of the strongest output unit
    if pl ~= testlabels(im)+1
        miss = miss + 1;
        missimages(miss) = im;
        misslabels(miss) = pl - 1;
    end
end
disp(['Miss: ' num2str(miss) ' out of ' num2str(numtest)]);
end

function [Z1,I2,Z2,Z3] = cnntrain_forward(img,W1,B1,S2,B2,W3,B3)
% CNNTRAIN_FORWARD Forward pass of the network for one normalized image.
% Returns the layer-1 activations Z1, the pooled maps I2 (before scaling),
% the layer-2 activations Z2 and the output vector Z3.
fn = size(W1,3);
% Layer 1: convolution with bias followed by tanh squashing. The kernel is
% flipped because convn flips it again.
A1 = zeros(size(img,1)-size(W1,1)+1,size(img,2)-size(W1,2)+1,fn);
for fm=1:fn
    A1(:,:,fm) = convn(img,W1(end:-1:1,end:-1:1,fm),'valid') + B1(fm);
end
Z1 = tanh(A1);
% Layer 2: average/subsample with scaling and bias
I2 = zeros(size(A1,1)/2,size(A1,2)/2,fn);
A2 = zeros(size(A1,1)/2,size(A1,2)/2,fn);
for fm=1:fn
    I2(:,:,fm) = avgpool(Z1(:,:,fm));
    A2(:,:,fm) = I2(:,:,fm) * S2(fm) + B2(fm);
end
Z2 = tanh(A2);
% Layer 3: fully connected; the 'valid' convolution with the fully flipped
% weight block collapses to a single weighted sum per class
ncl = size(W3,4);
A3 = zeros(ncl,1);
for cl=1:ncl
    A3(cl) = convn(Z2,W3(end:-1:1,end:-1:1,end:-1:1,cl),'valid') + B3(cl);
end
Z3 = tanh(A3); % Final output
end
function [pr] = avgpool(img)
% AVGPOOL Average every non-overlapping 2x2 window of img.
% The result has half as many rows and half as many columns as img.
nrows = size(img,1);
ncols = size(img,2);
top = 1:2:nrows;    % first row of each window
left = 1:2:ncols;   % first column of each window
pr = zeros(size(img)/2);
% Add the four window corners in the order top-left, bottom-left, top-right,
% bottom-right, then divide by the window area.
pr(1:numel(top),1:numel(left)) = ...
    (img(top,left) + img(top+1,left) + img(top,left+1) + img(top+1,left+1)) / 4;
end
显示准确性的代码
function [] = showmiss(missim,misslab,testimages,testlabels,numshow,numpages)
% SHOWMISS Display misclassified test images in pages of subplots.
%
% Inputs:
%   missim     - vector whose nonzero entries are indices into the third
%                dimension of testimages
%   misslab    - label predicted for each entry of missim
%   testimages - image array, one image per slice of dimension 3
%   testlabels - true label of each test image
%   numshow    - number of images per figure
%   numpages   - display stops once the page counter exceeds this value
%
% Each subplot is titled 'true:predicted'.

nummiss = nnz(missim); % only the nonzero entries are used
% Lay the subplots out on a grid that is large enough for numshow tiles even
% when numshow is not a perfect square (for a perfect square this is the same
% sqrt-by-sqrt grid as before).
gridcols = ceil(sqrt(numshow));
gridrows = ceil(numshow / gridcols);
page = 1;
for f=1:numshow:nummiss
    figure(page); % one figure per page, numbered from 1
    for m=f:min(nummiss,f+numshow-1)
        subplot(gridrows,gridcols,m-f+1);
        % The empty display range makes imshow scale to the data range, so
        % double-valued images are not clipped to [0,1].
        imshow(testimages(:,:,missim(m)),[]);
        title(strcat(num2str(testlabels(missim(m))), ':', num2str(misslab(m))));
    end
    page = page + 1;
    if page > numpages
        break;
    end
end
end
toint 函数:
function [x] = toint(b)
% TOINT Combine the first four entries of b into one integer, treating b(1)
% as the most significant byte and b(4) as the least significant.
x = b(1);
for k = 2:4
    % shift the accumulated value up by one byte and append the next one
    x = x * 256 + b(k);
end
end