我一直在尝试使用mlpack在C ++中实现随机森林。 我的数据具有一些分类特征。 我一直在尝试使用mlpack的DatasetInfo,但没有成功。
代码如下:
#include "pch.h"
#include <iostream>
using namespace arma;
using namespace mlpack;
using namespace mlpack::tree;
using namespace mlpack::cv;
int main()
{
cout << "[SAMPLE:BEGIN]";
// (1) Load the dataset
cout << "\nLoading dataset...";
mat dataset;
data::DatasetInfo di;
di = data::DatasetInfo(0);
bool loaded = data::Load("data/final.csv", dataset, di);
if (!loaded)
return -1;
di.Type(0) = data::Datatype::numeric;
di.Type(1) = data::Datatype::categorical;
di.Type(2) = data::Datatype::categorical;
Row<size_t> labels;
// Extract the labels from the last dimension of the training set
//labels = conv_to<Row<size_t>>::from(dataset.row(dataset.n_rows - 1));
loaded = data::Load("data/labels.csv", labels);
// Remove the labels from the training set
//dataset.shed_row(dataset.n_rows - 1);
// (2) Training
cout << "\nTraining...";
const size_t numClasses = 2;
const size_t minimumLeafSize = 5;
const size_t numTrees = 10;
RandomForest<GiniGain, RandomDimensionSelect> rf;
rf = RandomForest<GiniGain, RandomDimensionSelect>(dataset, di, labels,
numClasses, numTrees, minimumLeafSize);
Row<size_t> predictions;
rf.Classify(dataset, predictions);
const size_t correct = arma::accu(predictions == labels);
cout << "\nTraining Accuracy: " << (double(correct) / double(labels.n_elem));
//Save the model
cout << "\nSaving model...";
mlpack::data::Save("mymodel.xml", "model", rf, false);
//Load the model
cout << "\nLoading model...";
mlpack::data::Load("mymodel.xml", "model", rf);
// (6) Classify a new sample
cout << "\nClassifying a new sample...";
mat sample("67.00,5812,901");
mat probabilities;
rf.Classify(sample, predictions, probabilities);
u64 result = predictions.at(0);
cout << "\nClassification result: " << result << " , Probabilities: " <<
probabilities.at(0) << "/" << probabilities.at(1);
cout << "\n[SAMPLE:END]\n";
return 0;
}
我已将数据拆分为一个文件,该文件具有三个功能,其他文件具有标签。 final.csv具有功能。
1548.0,5964,812
其中第一列为数字,其他两列应视为类别。
labels.csv具有0
和1
形式的标签。
当我尝试训练机器时,此示例由于读取访问冲突异常而崩溃。
我认为我在尝试指定DatasetInfo的方式上做错了事。
任何人都可以指出这里出了什么问题或指向一些示例,在其中可以了解如何使用DatasetInfo。
谢谢。