Answer 0 (score: 9)
If you only want the leaf that each sample ends up in, you can use
clf.apply(iris.data)
array([1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,5, 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5, 5,5,14,5,5,5,5,5,5,10,5,5,5,5,5,10,5, 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,16,16, 16,16,16,16,6,16,16,16,16,16,16,16,16,16,16,16,16, 8,16,16,16,16,16,16,15,16,16,11,16,16,16,8,8,16, 16,16,15,16,16,16,16,16,16,16,16,16,16,16])
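If the goal is to see which samples share a leaf, the leaf ids returned by apply() can be grouped. The following is a minimal sketch, assuming the same iris classifier used in the rest of this answer; the names leaf_ids and samples_per_leaf are illustrative and not part of scikit-learn.

```python
import collections

import sklearn.datasets
import sklearn.tree

iris = sklearn.datasets.load_iris()
clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
clf = clf.fit(iris.data, iris.target)

# apply() returns, for each sample, the index of the leaf it lands in.
leaf_ids = clf.apply(iris.data)

# Group sample indices by leaf id (samples_per_leaf is an illustrative name).
samples_per_leaf = collections.defaultdict(list)
for sample_idx, leaf_id in enumerate(leaf_ids):
    samples_per_leaf[int(leaf_id)].append(sample_idx)
```

Each key of samples_per_leaf is a leaf node id, and each value lists the row indices of iris.data that reach that leaf.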
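If you want to get all the samples for each node, you can compute all the decision paths with
dec_paths = clf.decision_path(iris.data)
Then loop over the decision paths, convert each one to an array with toarray(), and check whether the sample passes through a given node. Everything is stored in a defaultdict, where the key is the node number and the values are the sample numbers.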
for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        if dec.toarray()[0][i] == 1:
            samples[i].append(d)
Complete code
import sklearn.datasets
import sklearn.tree
import collections
clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
iris = sklearn.datasets.load_iris()
clf = clf.fit(iris.data, iris.target)
samples = collections.defaultdict(list)
dec_paths = clf.decision_path(iris.data)
for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        if dec.toarray()[0][i] == 1:
            samples[i].append(d)
Output
print(samples[13])
[70,126,138]
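The loop above calls toarray() once per sample and node, which can get slow on larger datasets. As an alternative, here is a sketch that reads the same information directly from the sparse matrix, assuming decision_path() returns a SciPy sparse indicator matrix of shape (n_samples, n_nodes); the names paths_csc and samples_by_node are illustrative.

```python
import sklearn.datasets
import sklearn.tree

iris = sklearn.datasets.load_iris()
clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
clf = clf.fit(iris.data, iris.target)

# Convert the indicator matrix to CSC layout: the stored row indices of
# column i are then exactly the samples that pass through node i.
paths_csc = clf.decision_path(iris.data).tocsc()

samples_by_node = {}
for i in range(clf.tree_.node_count):
    start, end = paths_csc.indptr[i], paths_csc.indptr[i + 1]
    # Sort because row indices within a column are not guaranteed to be ordered.
    samples_by_node[i] = sorted(paths_csc.indices[start:end].tolist())
```

With this layout, samples_by_node[i] should hold the same sample numbers as samples[i] in the loop version, and iris.data[samples_by_node[i]] selects the corresponding feature rows.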