如何使Caffe Matlab包装器适用于在Mnist上训练的网络?

时间:2015-05-19 16:43:40

标签: matlab caffe matcaffe

我在http://caffe.berkeleyvision.org/gathered/examples/mnist.html

之后成功地在mnist数据库上训练了我的Caffe网络

现在我想使用Matlab包装器使用我自己的图像测试网络。

因此在“matcaffe.m”中加载文件“lenet.prototxt”,它不用于培训,但似乎适合测试。它引用的输入大小为28 x 28像素:

name: "LeNet"
input: "data"
input_dim: 64
input_dim: 1
input_dim: 28
input_dim: 28
layer {
name: "conv1"
type: "Convolution"
bottom: "data"
top: "conv1"

因此,我相应地在“matcaffe.m”中改编了“prepare_image”功能。它现在看起来像这样:

% ------------------------------------------------------------------------
function images = prepare_image(im)
IMAGE_DIM = 28;
% resize to fixed input size
im = rgb2gray(im);
im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear');
im = single(im);
images = zeros(1,1,IMAGE_DIM,IMAGE_DIM);
images(1,1,:,:) = im;
images = single(images);
%-------------------------------------------------------------

这会将输入图像转换为[1 x 1 x 28 x 28],4dim,灰度图像。但Matlab仍在抱怨:

Error using caffe
MatCaffe input size does not match the input size of the
network
Error in matcaffe_myModel_mnist (line 76)
scores = caffe('forward', input_data);

有人有经验,可以根据自己的数据测试受过训练的网络吗?

2 个答案:

答案 0 :(得分:3)

您遇到该错误(输入大小不匹配)的原因是网络原型文件需要一批64个图像。行

input_dim: 64
input_dim: 1
input_dim: 28
input_dim: 28

意味着网络预计会有一批64个灰度,28个28个图像。如果保持所有MATLAB代码相同并将第一行更改为

input_dim: 1

你的问题应该消失。

答案 1 :(得分:3)

最后我找到了完整的解决方案: 这是如何使用matca的matcaffe.m(Matlab包装器)预测自己输入图像的数字

  1. 在“matcaffe.m”中:必须引用文件“caffe-master / examples / mnist / lenet.prototxt”
  2. 根据mprat的指示调整文件“lenet.prototxt”:将条目input_dim更改为input_dim: 1
  3. 使用matcaffe.m中的子功能“prepare_image”的后续改编:
  4. (输入可以是任何大小的rgb图像)

    function image = prepare_image(im)
    
    IMAGE_DIM = 28;
    
    % If input image is too big , is rgb and of type uint8:
    % -> resize to fixed input size, single channel, type float
    
    im = rgb2gray(im);
    im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear');
    im = single(im);
    
    % Caffe needs a 4D input matrix which has single precision
    % Data has to be scaled by 1/256 = 0.00390625 (like during training)
    % In the second last line the image is beeing transposed!
    images = zeros(1,1,IMAGE_DIM,IMAGE_DIM);
    images(1,1,:,:) = 0.00390625*im';
    images = single(images);