我正在使用scikit-learn构建决策树模型,之后我想进行一些叶重写。基本上,我想更改特定叶节点的标签。
我在树叶上循环,并根据
tree.DecisionTreeClassifier.tree_
我可以得到
tree_.value
为了计算节点的标签。我是从here获得的。
我的问题是,是否可以以及如何强制更改决策树节点的标签?
目前,我尝试手动更改tree_.value
from sklearn import tree
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
df = pd.read_csv("voting.csv", header=0)
y = pd.DataFrame(df.target)
feature_names = []
for col in df.columns:
if col != 'target':
feature_names.append(col)
y = df.target
df = df.drop("target", 1)
thr = 0.9
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2)
clf = tree.DecisionTreeClassifier(min_samples_leaf=3)
clf.fit(X_train, y_train)
node_count = clf.tree_.node_count
class_label = 0
for index in range(node_count):
# check if it is a leaf
if clf.tree_.children_right[index] == -1 and clf.tree_.children_left[index] == -1:
# number of samples in the leaf (correctly classified and misclassified)
print("Values: ", clf.tree_.value[index])
# Finding node label
node_label = clf.classes_[np.argmax(clf.tree_.value[index])]
values = clf.tree_.value[index]
correct_samples = values[0][node_label]
misclassified_samples = np.sum(clf.tree_.value[index]) - correct_samples
# Change the label if number of misclassified samples is more than 0
if misclassified_samples > 0 and node_label != class_label:
clf.tree_.value[index][class_label] = clf.tree_.value[index][class_label] + correct_samples
print("New values: ", clf.tree_.value[index])
但这会导致更改两个值,即使正确分类也是如此。然后,节点的标签保持不变。例如,在操作之前:
Values: [[1. 2.]]
以及手术后:
New values: [[3. 4.]]
谢谢!