是否可以在scikit-learn的DecisionTreeRegressor中检索每个叶子中的列车行ID?

时间:2015-11-05 13:31:59

标签: python scikit-learn decision-tree

目前,我可以检索我的训练样本中每个节点的ID,我的测试样本的每一行最有可能属于该节点:

tree.tree_.apply(np.array(X_test).astype(np.float32))其中X_test表示决策树的输入。

但是,对于我种植的树的每片叶子,我想获得其中包含的训练样本的ID。这样我就知道哪个训练样本与一个测试输入最相似。

1 个答案:

答案 0 :(得分:1)

我最终使用“apply”函数到我的训练样本中来获取它所属的leaf_id。

def get_nearest_points(self, tr, input_train):
  inside_leaves = {}
  tmp = tr.tree_.apply(np.array(input_train).astype(np.float32))
  leaves_list = set(tmp)

  for leaf in leaves_list:
    inside_leaves[leaf] = [idx for idx, elt in enumerate(tmp) if elt == leaf]
  return inside_leaves

inside_leaves现在是一个字典,其中包含每个leaf_id,包含此叶子中涉及的行的列表。