我试图在matlab中用3个卷积层实现一个简单的完全卷积网络(它只是为了学习和理解FCN), 我使用Matlab过滤函数(如imfilter,filter2和...)和梯度下降作为学习算法。 我使用了Sigmoid或ReLu作为激活函数, 但问题是它不能学习我使用的功能或学习率。
到目前为止我的实施:
clear
clc
%% image read and convert to gray scale
RGBimg0 = imread('0.jpg');
RGBimg1 = imread('1.jpg');
RGBimg00 = imread('00.jpg');
RGBimg11 = imread('11.jpg');
RGBimg000 = imread('000.jpg');
RGBimg111 = imread('111.jpg');
img = {double(rgb2gray(RGBimg0)) double(rgb2gray(RGBimg1)) ...
double(rgb2gray(RGBimg00)) double(rgb2gray(RGBimg11)) ...
double(rgb2gray(RGBimg000)) double(rgb2gray(RGBimg111))};
img{1} = imresize(img{1},[256 256]);
img{2} = imresize(img{2},[256 256]);
img{3} = imresize(img{3},[256 256]);
img{4} = imresize(img{4},[256 256]);
img{5} = imresize(img{5},[256 256]);
img{6} = imresize(img{6},[256 256]);
d0 = double(rgb2gray(imread('d0.jpg')));
d00 = double(rgb2gray(imread('d00.jpg')));
d000 = double(rgb2gray(imread('d000.jpg')));
d1 = double(rgb2gray(imread('d1.jpg')));
d11 = double(rgb2gray(imread('d11.jpg')));
d111 = double(rgb2gray(imread('d111.jpg')));
desired = {d0 d1 d00 d11 d000 d111};
desired{1} = imresize(desired{1},[256 256]);
desired{2} = imresize(desired{2},[256 256]);
desired{3} = imresize(desired{3},[256 256]);
desired{4} = imresize(desired{4},[256 256]);
desired{5} = imresize(desired{5},[256 256]);
desired{6} = imresize(desired{6},[256 256]);
%Normalise image (0 to 1)
for i=1 : 6
a = img{i}(:);
range = max(a) - min(a);
normm = (a - min(a)) / range;
img{i} = reshape(normm,256,256);
end
%Normalise desired (0 to 1)
for i=1 : 6
a = desired{i}(:);
range = max(a) - min(a);
normm = (a - min(a)) / range;
desired{i} = reshape(normm,256,256);
end
%% Convolution initialize
%filter definitions (randiom 3x3, range 0 to 1)
f11 = (-1 + (1+1)*rand(3));
f12 = (-1 + (1+1)*rand(3));
f13 = (-1 + (1+1)*rand(3));
f211 = (-1 + (1+1)*rand(3));
f212 = (-1 + (1+1)*rand(3));
f213 = (-1 + (1+1)*rand(3));
f221 = (-1 + (1+1)*rand(3));
f222 = (-1 + (1+1)*rand(3));
f223 = (-1 + (1+1)*rand(3));
f231 = (-1 + (1+1)*rand(3));
f232 = (-1 + (1+1)*rand(3));
f233 = (-1 + (1+1)*rand(3));
f311 = (-1 + (1+1)*rand(3));
f312 = (-1 + (1+1)*rand(3));
f313 = (-1 + (1+1)*rand(3));
f321 = (-1 + (1+1)*rand(3));
f322 = (-1 + (1+1)*rand(3));
f323 = (-1 + (1+1)*rand(3));
f331 = (-1 + (1+1)*rand(3));
f332 = (-1 + (1+1)*rand(3));
f333 = (-1 + (1+1)*rand(3));
f41 = (-1 + (1+1)*rand(3));
f42 = (-1 + (1+1)*rand(3));
f43 = (-1 + (1+1)*rand(3));
imgOut = zeros(256,256,6);
learnRate = 0.00000000001;
reluSlop = 0.9;
maxIt = 2000000;
err = zeros(35,1);
ETotal = zeros(maxIt,1);
%% Main Loop
for iteration = 1 : maxIt
for imgCount = 1 : 1 %size(img,2)
%% Covolution
% convolve 1
activeMap11 = imfilter(img{imgCount},f11,0);
activeMap12 = imfilter(img{imgCount},f12,0);
activeMap13 = imfilter(img{imgCount},f13,0);
% relu 1
activeMap11 = Relu(activeMap11, reluSlop, false);
activeMap12 = Relu(activeMap12, reluSlop, false);
activeMap13 = Relu(activeMap13, reluSlop, false);
% ---------------------------
% convolve 2
activeMap21 = imfilter(activeMap11,f211,0) + imfilter(activeMap12,f221,0) + imfilter(activeMap13,f231,0);
activeMap22 = imfilter(activeMap11,f212,0) + imfilter(activeMap12,f222,0) + imfilter(activeMap13,f232,0);
activeMap23 = imfilter(activeMap11,f213,0) + imfilter(activeMap12,f223,0) + imfilter(activeMap13,f233,0);
% relu 2
activeMap21 = Relu(activeMap21, reluSlop, false);
activeMap22 = Relu(activeMap22, reluSlop, false);
activeMap23 = Relu(activeMap23, reluSlop, false);
% ---------------------------
% convolve 3
activeMap31 = imfilter(activeMap21,f311,0) + imfilter(activeMap22,f321,0) + imfilter(activeMap23,f331,0);
activeMap32 = imfilter(activeMap21,f312,0) + imfilter(activeMap22,f322,0) + imfilter(activeMap23,f332,0);
activeMap33 = imfilter(activeMap21,f313,0) + imfilter(activeMap22,f323,0) + imfilter(activeMap23,f333,0);
% relu 3
activeMap31 = Relu(activeMap31, reluSlop, false);
activeMap32 = Relu(activeMap32, reluSlop, false);
activeMap33 = Relu(activeMap33, reluSlop, false);
% ---------------------------
% convolve 4
activeMap4 = imfilter(activeMap31,f41,0) + imfilter(activeMap32,f42,0) + imfilter(activeMap33,f43,0);
% relu 4
activeMap4 = Relu(activeMap4, reluSlop, false);
imgOut(:,:,imgCount) = activeMap4;
%% Backpropagation
errMat = 0.5 .* ((desired{imgCount} - activeMap4).^2);
err(imgCount) = sum(errMat(:));
%filter deltas----------------------------------------------
error = desired{imgCount} - activeMap4;
delta_activeMap4 = error .* Relu(activeMap4, reluSlop, true);
delta_activeMap31 = imfilter(delta_activeMap4,f41,0);
delta_activeMap32 = imfilter(delta_activeMap4,f42,0);
delta_activeMap33 = imfilter(delta_activeMap4,f43,0);
delta_activeMap31 = delta_activeMap31 .* Relu(activeMap31, reluSlop, true);
delta_activeMap32 = delta_activeMap32 .* Relu(activeMap32, reluSlop, true);
delta_activeMap33 = delta_activeMap33 .* Relu(activeMap33, reluSlop, true);
delta_activeMap21 = imfilter(delta_activeMap31,f311,0) + imfilter(delta_activeMap32,f312,0) + imfilter(delta_activeMap33,f313,0);
delta_activeMap22 = imfilter(delta_activeMap31,f321,0) + imfilter(delta_activeMap32,f322,0) + imfilter(delta_activeMap33,f323,0);
delta_activeMap23 = imfilter(delta_activeMap31,f331,0) + imfilter(delta_activeMap32,f332,0) + imfilter(delta_activeMap33,f333,0);
delta_activeMap21 = delta_activeMap21 .* Relu(activeMap21, reluSlop, false);
delta_activeMap22 = delta_activeMap22 .* Relu(activeMap22, reluSlop, false);
delta_activeMap23 = delta_activeMap23 .* Relu(activeMap23, reluSlop, false);
delta_activeMap11 = imfilter(delta_activeMap21,f211,0) + imfilter(delta_activeMap22,f212,0) + imfilter(delta_activeMap23,f213,0);
delta_activeMap12 = imfilter(delta_activeMap21,f221,0) + imfilter(delta_activeMap22,f222,0) + imfilter(delta_activeMap23,f223,0);
delta_activeMap13 = imfilter(delta_activeMap21,f231,0) + imfilter(delta_activeMap22,f232,0) + imfilter(delta_activeMap23,f233,0);
delta_activeMap11 = delta_activeMap11 .* Relu(activeMap11, reluSlop, false);
delta_activeMap12 = delta_activeMap12 .* Relu(activeMap12, reluSlop, false);
delta_activeMap13 = delta_activeMap13 .* Relu(activeMap13, reluSlop, false);
activeMap31 = padarray(activeMap31,[1 1]);
activeMap32 = padarray(activeMap32,[1 1]);
activeMap33 = padarray(activeMap33,[1 1]);
activeMap21 = padarray(activeMap21,[1 1]);
activeMap22 = padarray(activeMap22,[1 1]);
activeMap23 = padarray(activeMap23,[1 1]);
activeMap11 = padarray(activeMap11,[1 1]);
activeMap12 = padarray(activeMap12,[1 1]);
activeMap13 = padarray(activeMap13,[1 1]);
paddedImg1 = padarray(img{imgCount},[1 1]);
paddedImg2 = padarray(img{imgCount},[1 1]);
paddedImg3 = padarray(img{imgCount},[1 1]);
%Weight update
f41 = f41 + (learnRate .* (filterDelta4(activeMap31, delta_activeMap4)));
f42 = f41 + (learnRate .* (filterDelta4(activeMap32, delta_activeMap4)));
f43 = f41 + (learnRate .* (filterDelta4(activeMap33, delta_activeMap4)));
f311 = f311 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap31)));
f321 = f321 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap31)));
f331 = f331 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap31)));
f312 = f312 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap32)));
f322 = f322 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap32)));
f332 = f332 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap32)));
f313 = f313 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap33)));
f323 = f323 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap33)));
f333 = f333 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap33)));
f211 = f211 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap21)));
f221 = f221 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap21)));
f231 = f231 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap21)));
f212 = f212 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap22)));
f222 = f222 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap22)));
f232 = f232 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap22)));
f213 = f213 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap23)));
f223 = f223 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap23)));
f233 = f233 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap23)));
f11 = f11 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap11)));
f12 = f12 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap12)));
f13 = f13 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap13)));
fprintf('[%f]-', err(imgCount));
end
fprintf('%i \n', iteration);
ETotal(iteration) = err(1);
imshow([imgOut(:,:,1)],[])
end
plot(ETotal);
正如您所看到的,我的学习率非常小,因为它的滤镜渐变非常大。