Reading the binary CIFAR-10 dataset in dlib for deep learning

Date: 2018-02-03 15:44:05

Tags: c++ database deep-learning data-conversion dlib

I want to classify images in dlib using the MNIST example code: http://dlib.net/dnn_introduction_ex.cpp.html

However, my dataset will be CIFAR-10, in its binary version: http://www.cs.toronto.edu/~kriz/cifar.html

I don't know how to get dlib to read it and train on it.

I found a binary reader that converts it into a struct, but I still can't make the result trainable: https://github.com/wichtounet/cifar-10

My code:

#include "stdafx.h"
#include <cifar\cifar10_reader.hpp>
#include <dlib/dnn.h>
#include <iostream>
#include <dlib/data_io.h>


using namespace std;

using namespace dlib;

int main(int argc, char** argv) try
{

auto dataset = cifar::read_dataset<std::vector, std::vector, uint8_t, uint8_t>();

using net_type = loss_multiclass_log<
    fc<10,
    relu<fc<84,
    relu<fc<120,
    max_pool<2, 2, 2, 2, relu<con<16, 5, 5, 1, 1,
    max_pool<2, 2, 2, 2, relu<con<6, 5, 5, 1, 1,
    input<matrix<unsigned char>>
    >>>>>>>>>>>>;

net_type net;

dnn_trainer<net_type> trainer(net);
trainer.set_learning_rate(0.01);
trainer.set_min_learning_rate(0.00001);
trainer.set_mini_batch_size(128);
trainer.be_verbose();

trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));

trainer.train(dataset.training_images, dataset.training_labels);

net.clean();
serialize("cifar.dat") << net;

std::vector<unsigned long> predicted_labels = net(dataset.training_images);
int num_right = 0;
int num_wrong = 0;

for (size_t i = 0; i < dataset.training_images.size(); ++i)
{
    if (predicted_labels[i] == dataset.training_labels[i])
        ++num_right;
    else
        ++num_wrong;

}
cout << "training num_right: " << num_right << endl;
cout << "training num_wrong: " << num_wrong << endl;
cout << "training accuracy:  " << num_right / (double)(num_right + num_wrong) << endl;


predicted_labels = net(dataset.test_images);
num_right = 0;
num_wrong = 0;
for (size_t i = 0; i < dataset.test_images.size(); ++i)
{
    if (predicted_labels[i] == dataset.test_labels[i])
        ++num_right;
    else
        ++num_wrong;

}
cout << "testing num_right: " << num_right << endl;
cout << "testing num_wrong: " << num_wrong << endl;
cout << "testing accuracy:  " << num_right / (double)(num_right + num_wrong) << endl;



net_to_xml(net, "cif.xml");
}
catch (std::exception& e)
 {
     cout << e.what() << endl;
 } 

I only replaced the original training_images vector with dataset.training_images. But the program does not compile, because of an overload error at:
trainer.train(dataset.training_images, dataset.training_labels);

I am not sure what type dataset.training_images and the other members are, or how to use them.
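The overload error is most likely a type mismatch. dnn_trainer::train expects a std::vector of the network's input type (here matrix<unsigned char>) and a std::vector of the loss layer's label type (unsigned long for loss_multiclass_log). Judging by the template arguments, the reader probably returns std::vector<std::vector<uint8_t>> for the images and std::vector<uint8_t> for the labels, though that is an assumption rather than something checked against the reader's source. The sketch below shows the conversion step using only the standard library. `Rgb`, `to_interleaved` and `widen_labels` are hypothetical names, and `Rgb` stands in for dlib::rgb_pixel so that the example builds without dlib.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for dlib::rgb_pixel so the sketch builds without dlib.
struct Rgb {
    std::uint8_t red, green, blue;
};

constexpr std::size_t kSide = 32;              // CIFAR-10 images are 32x32
constexpr std::size_t kPlane = kSide * kSide;  // 1024 values per colour plane

// Converts one planar image (R plane, G plane, B plane) into row-major
// interleaved pixels. Assumes the reader keeps the on-disk planar layout.
std::vector<Rgb> to_interleaved(const std::vector<std::uint8_t>& planar) {
    if (planar.size() != 3 * kPlane)
        throw std::invalid_argument("expected 3072 bytes per image");
    std::vector<Rgb> out(kPlane);
    for (std::size_t i = 0; i < kPlane; ++i)
        out[i] = Rgb{planar[i], planar[kPlane + i], planar[2 * kPlane + i]};
    return out;
}

// Widens uint8_t labels to unsigned long, the label type used by
// dlib's loss_multiclass_log.
std::vector<unsigned long> widen_labels(const std::vector<std::uint8_t>& labels) {
    return std::vector<unsigned long>(labels.begin(), labels.end());
}
```

With dlib, the same loop would fill a matrix<rgb_pixel> of 32 rows by 32 columns per image, where pixel i lands at row i / 32 and column i % 32. The network's input layer would then be input<matrix<rgb_pixel>> instead of input<matrix<unsigned char>>, and the resulting std::vector<matrix<rgb_pixel>> and std::vector<unsigned long> would be passed to trainer.train.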

0 answers:

No answers