我想参考 dlib 的 MNIST 示例代码来对图像进行分类: http://dlib.net/dnn_introduction_ex.cpp.html
但我的数据集是 CIFAR-10 的二进制版本(binary version): http://www.cs.toronto.edu/~kriz/cifar.html
我不知道如何让 dlib 读取这些数据并用它们进行训练。
我找到了一个能把二进制文件读取成结构体的读取器,但仍然无法把读出的数据交给 dlib 训练 :/ https://github.com/wichtounet/cifar-10
我的代码:
#include "stdafx.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

#include <cifar/cifar10_reader.hpp>
#include <dlib/data_io.h>
#include <dlib/dnn.h>

using namespace std;
using namespace dlib;
int main(int argc, char** argv) try
{
auto dataset = cifar::read_dataset<std::vector, std::vector, uint8_t, uint8_t>();
using net_type = loss_multiclass_log<
fc<10,
relu<fc<84,
relu<fc<120,
max_pool<2, 2, 2, 2, relu<con<16, 5, 5, 1, 1,
max_pool<2, 2, 2, 2, relu<con<6, 5, 5, 1, 1,
input<matrix<unsigned char>>
>>>>>>>>>>>>;
net_type net;
dnn_trainer<net_type> trainer(net);
trainer.set_learning_rate(0.01);
trainer.set_min_learning_rate(0.00001);
trainer.set_mini_batch_size(128);
trainer.be_verbose();
trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
trainer.train(dataset.training_images, dataset.training_labels);
net.clean();
serialize("cifar.dat") << net;
std::vector<unsigned long> predicted_labels = net(dataset.training_images);
int num_right = 0;
int num_wrong = 0;
for (size_t i = 0; i < dataset.training_images.size(); ++i)
{
if (predicted_labels[i] == dataset.training_labels[i])
++num_right;
else
++num_wrong;
}
cout << "training num_right: " << num_right << endl;
cout << "training num_wrong: " << num_wrong << endl;
cout << "training accuracy: " << num_right / (double)(num_right + num_wrong) << endl;
predicted_labels = net(dataset.test_images);
num_right = 0;
num_wrong = 0;
for (size_t i = 0; i < dataset.test_images.size(); ++i)
{
if (predicted_labels[i] == dataset.test_labels[i])
++num_right;
else
++num_wrong;
}
cout << "testing num_right: " << num_right << endl;
cout << "testing num_wrong: " << num_wrong << endl;
cout << "testing accuracy: " << num_right / (double)(num_right + num_wrong) << endl;
net_to_xml(net, "cif.xml");
}
catch (std::exception& e)
{
cout << e.what() << endl;
}
我只是把示例中原来的向量 training_images 等替换成了 dataset.training_images 等。但是程序无法编译,因为下面这行调用找不到匹配的重载:
trainer.train(dataset.training_images, dataset.training_labels);
我不确定 dataset.training_images 等成员是什么类型,也不知道该如何使用它们。