我正在C ++ 17中实现决策树。我将数据集中的每条记录存储为一个元组,并包含一个包含每个元组的向量。
我需要为在元组字段中找到的每个唯一值创建元组的子集。
我了解访问元组中的元素是在编译时完成的,所以我想知道是否可以使用任何技巧,或者对于这个特定项目我是否应该完全放弃元组。
此外,在拆分树时,我需要在将来的迭代中删除最佳属性。我以为使用index_sequence可能有效,但是任何在线执行的实例都是连续的数字(例如1、2、3、4、5),而不是随机的子集(例如1、3、5)
zip_recordset函数采用一个元组向量(一个RecordSet)并创建一个向量元组。我以为这会很有用,但是我仍然需要在运行时访问哪个向量。
#ifndef __DECISION_TREE
#define __DECISION_TREE
#include <algorithm>
#include <set>
#include <map>
#include "common.h"
namespace ml::dt {
struct Node {
bool leaf;
size_t attr;
size_t label;
std::vector<Node> children;
};
template<typename T>
class DecisionTree {
using RecordSet = typename T::RecordSet;
using Record = typename T::Record;
Node head;
template<typename U>
double entropy(std::vector<U> attrVec) {
std::map<U, std::size_t> valMap {};
for(auto i=0; i<attrVec.size(); ++i) {
if(valMap.count(attrVec[i]) == 0) {
valMap[attrVec[i]] = 1;
} else {
valMap[attrVec[i]]++;
}
}
int totalSize = 0;
for(auto& [key, val] : valMap) {
totalSize += val;
}
double entropy = 0;
for(auto& [key, val] : valMap) {
double p = (static_cast<double>(val)/totalSize);
entropy += p * std::log2(p);
}
return entropy;
}
template<auto target>
Node learn_impl(RecordSet trainSet, std::set<std::size_t> attrList) {
int pos = 0;
int neg = 0;
for(auto& rec : trainSet) {
if(std::get<target>(rec) == 0) {
neg++;
} else {
pos++;
}
}
// If all examples are positive, return +
if(neg == 0) {
return Node {true, 0, 1, {}};
}
// If all examples are negative, return -
if(pos == 0) {
return Node {true, 0, 0, {}};
}
// If attr is empty, return label Mode(examples)
if(attrList.empty()) {
size_t label = 0;
if(neg < pos)
label = 1;
return Node {true, 0, label, {}};
}
// Get Best Attribute
size_t bestAttr = 0;
std::vector<double> entropy_values {};
auto trainSetVec = T::zip_recordset(trainSet);
auto clVec = std::get<T::size-1>(trainSetVec);
std::apply([&](auto& ...v) {
(entropy_values.emplace_back(entropy(v)),...);
}, trainSetVec);
bestAttr = std::distance(entropy_values.begin(), std::min_element(entropy_values.begin(), entropy_values.end()));
// For every value of attr
// examples_new = values with attr
// if examples_new empty
// leaf node = mode(examples)
// else
// learn_impl<target>(examples_new, attrList - bestAttr)
}
public:
DecisionTree() = default;
template<auto target>
void learn(RecordSet trainSet) {
std::set<std::size_t> attrList {};
for(auto i=0; i<T::size; ++i) {
attrList.emplace(i);
}
head = learn_impl<target>(trainSet, attrList);
}
};
}
#endif // __DECISION_TREE
有什么办法吗?
添加了更长的代码示例。 扩展了有关访问元组元素的说明。