mlpack中的分类功能

时间:2019-05-06 08:29:38

标签: c++11 mlpack

我一直在尝试使用mlpack在C ++中实现随机森林。 我的数据具有一些分类特征。 我一直在尝试使用mlpack的DatasetInfo,但没有成功。

代码如下:

#include "pch.h"
#include <iostream>
using namespace arma;
using namespace mlpack;
using namespace mlpack::tree;
using namespace mlpack::cv;

int main()
{
    cout << "[SAMPLE:BEGIN]";

    // (1) Load the dataset
    cout << "\nLoading dataset...";

    mat dataset;
    data::DatasetInfo di;
    di = data::DatasetInfo(0);
    bool loaded = data::Load("data/final.csv", dataset, di);
    if (!loaded)
        return -1;

    di.Type(0) = data::Datatype::numeric;
    di.Type(1) = data::Datatype::categorical;
    di.Type(2) = data::Datatype::categorical;

    Row<size_t> labels;

    // Extract the labels from the last dimension of the training set
    //labels = conv_to<Row<size_t>>::from(dataset.row(dataset.n_rows - 1));
    loaded = data::Load("data/labels.csv", labels);

    // Remove the labels from the training set
    //dataset.shed_row(dataset.n_rows - 1);

    // (2) Training
    cout << "\nTraining...";
    const size_t numClasses = 2;
    const size_t minimumLeafSize = 5;
    const size_t numTrees = 10;

    RandomForest<GiniGain, RandomDimensionSelect> rf; 

    rf = RandomForest<GiniGain, RandomDimensionSelect>(dataset, di, labels,
        numClasses, numTrees, minimumLeafSize);

    Row<size_t> predictions;
    rf.Classify(dataset, predictions);

    const size_t correct = arma::accu(predictions == labels);

    cout << "\nTraining Accuracy: " << (double(correct) / double(labels.n_elem));


    //Save the model
    cout << "\nSaving model...";
    mlpack::data::Save("mymodel.xml", "model", rf, false);

    //Load the model
    cout << "\nLoading model...";
    mlpack::data::Load("mymodel.xml", "model", rf);

    // (6) Classify a new sample
    cout << "\nClassifying a new sample...";
    mat sample("67.00,5812,901");
    mat probabilities;
    rf.Classify(sample, predictions, probabilities);
    u64 result = predictions.at(0);
    cout << "\nClassification result: " << result << " , Probabilities: " <<
        probabilities.at(0) << "/" << probabilities.at(1);
    cout << "\n[SAMPLE:END]\n";
    return 0;
}

我已将数据拆分为一个文件,该文件具有三个功能,其他文件具有标签。 final.csv具有功能。

1548.0,5964,812

其中第一列为数字,其他两列应视为类别。

labels.csv具有01形式的标签。

当我尝试训练机器时,此示例由于读取访问冲突异常而崩溃。

我认为我在尝试指定DatasetInfo的方式上做错了事。

任何人都可以指出这里出了什么问题或指向一些示例,在其中可以了解如何使用DatasetInfo。

谢谢。

0 个答案:

没有答案