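I want the decision path (i.e. the set of rules) from the root node to a given node (which I supply) in a scikit-learn decision tree (DecisionTreeClassifier). clf.decision_path gives the nodes a sample passes through, which may help to get the set of rules that sample follows, but how do I get the rule set leading to a particular node in the tree?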
Answer 0 (score: 3)
To visualize the decision rules for the nodes of a tree trained on the iris dataset:
from sklearn.datasets import load_iris
from sklearn import tree
import graphviz
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
#this will create an iris.pdf file with the rule path
graph.render("iris")
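If graphviz is not available, a plain-text version of the same rules can be printed instead. This is a sketch assuming scikit-learn 0.21 or later, where `sklearn.tree.export_text` exists; `random_state=0` is added only so the tree is reproducible. The rules leading to any node are the conditions on the less-indented lines above it.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
# random_state is fixed here only so repeated runs build the same tree
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# export_text returns the whole tree as an indented string of split conditions;
# each deeper "|---" level is one more condition on the path from the root
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```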
To print the rules that a particular sample follows, you can also walk the tree structure directly. The code below is adapted from the scikit-learn example on understanding the decision tree structure; lines marked `<--` are changes to that example:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)
# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of each
# array holds information about the node `i`. Node 0 is the tree's root. NOTE:
# Some of the arrays only apply to either leaves or split nodes. In this
# case the values of nodes of the other type are arbitrary!
#
# Among those arrays, we have:
# - children_left, id of the left child of the node
# - children_right, id of the right child of the node
# - feature, feature used for splitting the node
# - threshold, threshold value at the node
n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold
# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1
    # If we have a test node
    if children_left[node_id] != children_right[node_id]:
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True
print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (node_depth[i] * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
print()
# First let's retrieve the decision path of each sample. The decision_path
# method lets you retrieve the node indicator functions. A non-zero element of
# the indicator matrix at position (i, j) indicates that sample i goes
# through node j.
node_indicator = estimator.decision_path(X_test)
# Similarly, we can also get the ids of the leaves reached by each sample.
leave_id = estimator.apply(X_test)
# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's do it for a single sample.
# HERE IS WHAT YOU WANT
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]
print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        # continue  # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id]))  # <--
    else:  # <-- added else to iterate through decision nodes
        if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
            threshold_sign = "<="
        else:
            threshold_sign = ">"
        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]],  # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))
Output:
Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011920929)
decision id node 2 : (X[0, 2] (= 5.1) > 4.949999809265137)
leaf node 4 reached, no decision here
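The code above reports the rules for a given sample. To get the rules for an arbitrary node id, as the question asks, a minimal sketch using only the `tree_` arrays described in the comments above (`children_left`, `children_right`, `feature`, `threshold`) is to build a child-to-parent map and then climb from the requested node back to the root. `path_to_node` is a hypothetical helper name, not a scikit-learn API.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def path_to_node(estimator, node_id):
    """Return the split conditions from the root (node 0) down to node_id.

    Each condition is a tuple (feature_index, op, threshold), op in {"<=", ">"}.
    """
    tree = estimator.tree_
    left, right = tree.children_left, tree.children_right
    # Map every child id to (parent id, True if it is the parent's left child).
    # Leaves have children_left == -1, so they contribute no entries.
    parent = {}
    for p in range(tree.node_count):
        if left[p] != -1:
            parent[int(left[p])] = (p, True)
            parent[int(right[p])] = (p, False)
    node_id = int(node_id)
    if node_id != 0 and node_id not in parent:
        raise ValueError("node %d is not in the tree" % node_id)
    conditions = []
    # Climb from the requested node back up to the root, which has no parent.
    while node_id != 0:
        p, is_left = parent[node_id]
        # Samples with X[:, feature] <= threshold are sent to the left child
        op = "<=" if is_left else ">"
        conditions.append((int(tree.feature[p]), op, float(tree.threshold[p])))
        node_id = p
    return conditions[::-1]  # reverse so the root's split comes first


iris = load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0).fit(iris.data, iris.target)
for node in range(clf.tree_.node_count):
    rules = ["%s %s %.4f" % (iris.feature_names[f], op, t)
             for f, op, t in path_to_node(clf, node)]
    print(node, " AND ".join(rules) or "(root)")
```

Unlike `decision_path`, this needs no sample, so it works for internal nodes as well as leaves.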
Answer 1 (score: 2)
If you pass None as out_file to export_graphviz, you get a string representation of the tree.
from sklearn.datasets import load_iris
from sklearn import tree
clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
string_data = tree.export_graphviz(clf,
                                   out_file=None,
                                   feature_names=iris.feature_names,
                                   class_names=iris.target_names)
print(string_data)
#Output
digraph Tree {
node [shape=box] ;
0 [label="petal length (cm) <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa"] ;
1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="petal width (cm) <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="petal length (cm) <= 4.95\ngini = 0.168\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor"] ;
2 -> 3 ;
4 [label="petal width (cm) <= 1.65\ngini = 0.041\nsamples = 48\nvalue = [0, 47, 1]\nclass = versicolor"] ;
3 -> 4 ;
5 [label="gini = 0.0\nsamples = 47\nvalue = [0, 47, 0]\nclass = versicolor"] ;
4 -> 5 ;
6 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica"] ;
4 -> 6 ;
7 [label="petal width (cm) <= 1.55\ngini = 0.444\nsamples = 6\nvalue = [0, 2, 4]\nclass = virginica"] ;
3 -> 7 ;
8 [label="gini = 0.0\nsamples = 3\nvalue = [0, 0, 3]\nclass = virginica"] ;
7 -> 8 ;
9 [label="sepal length (cm) <= 6.95\ngini = 0.444\nsamples = 3\nvalue = [0, 2, 1]\nclass = versicolor"] ;
7 -> 9 ;
10 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]\nclass = versicolor"] ;
9 -> 10 ;
11 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica"] ;
9 -> 11 ;
12 [label="petal length (cm) <= 4.85\ngini = 0.043\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica"] ;
2 -> 12 ;
13 [label="sepal length (cm) <= 5.95\ngini = 0.444\nsamples = 3\nvalue = [0, 1, 2]\nclass = virginica"] ;
12 -> 13 ;
14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = versicolor"] ;
13 -> 14 ;
15 [label="gini = 0.0\nsamples = 2\nvalue = [0, 0, 2]\nclass = virginica"] ;
13 -> 15 ;
16 [label="gini = 0.0\nsamples = 43\nvalue = [0, 0, 43]\nclass = virginica"] ;
12 -> 16 ;
}
This should give you what you need. You can then easily write a program to parse this output and process it as required.
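As a sketch of that parsing step, assuming the default plain-text labels (`special_characters=False`) and that export_graphviz writes each node's left (True) child before its right child, as in the output above. `rules_from_dot` is a hypothetical helper, not part of scikit-learn.

```python
import re

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def rules_from_dot(dot, node_id):
    """Rebuild the root-to-node rule list from export_graphviz text output."""
    conditions = {}  # node id -> split condition (first line of its label)
    children = {}    # parent id -> child ids, in the order they appear
    parent = {}      # child id -> parent id
    for line in dot.splitlines():
        edge = re.match(r"(\d+) -> (\d+)", line)
        if edge:
            p, c = int(edge.group(1)), int(edge.group(2))
            children.setdefault(p, []).append(c)
            parent[c] = p
            continue
        node = re.match(r'(\d+) \[label="(.*?)"', line)
        if node:
            # Label fields are separated by a literal backslash followed by "n"
            conditions[int(node.group(1))] = node.group(2).split("\\n")[0]
    rules = []
    node_id = int(node_id)
    while node_id in parent:
        p = parent[node_id]
        cond = conditions[p]
        # Assumed: the first child listed for a parent is its left (True) child
        if children[p][0] != node_id:
            cond = cond.replace(" <= ", " > ", 1)
        rules.append(cond)
        node_id = p
    return rules[::-1]  # root's condition first


iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
dot = export_graphviz(clf, out_file=None, feature_names=iris.feature_names)
print(rules_from_dot(dot, 4))
```

Reading the `tree_` arrays directly avoids depending on the DOT text format, but parsing the string works when only the exported output is available.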