Caffe2:如何在C ++中加载和使用MNIST教程模型

时间:2017-09-14 09:46:00

标签: caffe2

我正在努力在受过训练的MNIST caffe2教程模型的C ++结果中进行复制。我所做的是我稍微修改了MNIST python教程(代码可用here),在python方面一切正常。

如果我运行mnist.py,我会得到两个" .pb"具有网络定义和初始化的文件。如果我在python端加载这个网并用DB中的一些图像提供它,那么我将得到正确的预测:

timg = np.fromfile('test_img.dat', dtype=np.uint8).reshape([28,28])
workspace.FeedBlob('data', (timg/256.).reshape([1,1,28,28]).astype(np.float32))
workspace.RunNet(net_def.name)
workspace.FetchBlob('softmax')
array([[  1.23242417e-05,   6.76146897e-07,   9.01260137e-06,
      1.60285403e-04,   9.54966026e-07,   6.82772861e-06,
      2.20508967e-09,   9.99059498e-01,   2.71651220e-06,
      7.47664250e-04]], dtype=float32)

所以很确定测试图像是' 7' (这是正确的。)

但是我无法从C ++中获得相同的结果。我已经了解了其他项目(herehere)是如何完成的,并提出了以下建议:

C ++ net initialization

QByteArray img_bytes; // where the raw image bytes are kept (size 28x28)
caffe2::NetDef init_net, predict_net;
caffe2::TensorCPU input;
// predictor and it's input/output vectors
std::unique_ptr<caffe2::Predictor> predictor;
caffe2::Predictor::TensorVector input_vec;
caffe2::Predictor::TensorVector output_vec;
...
QFile f("mnist_init_net.pb");

...
auto barr = f.readAll();
if (! init_net.ParseFromArray(barr.data(), barr.size())) {

...
f.setFileName("mnist_predict_net.pb");

...
barr = f.readAll();
if (! predict_net.ParseFromArray(barr.data(), barr.size())) {

...
predictor.reset(new caffe2::Predictor(init_net, predict_net));
input.Resize(std::vector<int>{{1, 1, IMG_H, IMG_W}});
input_vec.resize(1, &input);

此初始化运行没有问题。由于部署网络没有扩展和转换为浮动,我必须这样做(与上面的python代码段相同),我按如下方式执行:

float* data = input.mutable_data<float>();
for (int i = 0; i < img_bytes.size(); ++i)
    *data++ = float(img_bytes[i])/256.f;

最后我喂了预测器:

if (! predictor->run(input_vec, &output_vec) || output_vec.size() < 1
                                             || output_vec[0]->size() != 10)
...

我在同一档案中得到的结果是&#39; 7&#39;是17%(不是99.9%),其余类别约为5-10%。

现在我被困了,我不知道问题出在哪里,所以我很感激任何提示/提示/指示。

1 个答案:

答案 0 :(得分:2)

事实证明,我使用Caffe2没有问题,但我的使用 预处理。由于img_bytes是一个基本类型为char的QByteArray,因此默认情况下(在gcc中)char是一个带符号的类型,此转换和缩放:

*data++ = float(img_bytes[i])/256.f;

导致一些负值(而不是范围[0,1]中的浮点数)。正确的版本是:

*data++ = static_cast<unsigned char>(img_bytes[i])/256.f