我正在使用DecisionTreeClassifier研究数据样本,并且样本不平衡(很多0,几个1)。因此,我使用class_weight={0: 1, 1: 5}
参数训练模型。
训练后,我需要导出树以说明研究。我使用sklearn.tree.export_graphviz
,但它不计算样本中的实际对象数。它显示“值= [zeroes_count * 1,ones_count * 5]”,因此数字使我和我正在显示结果的任何人感到困惑。
是否有一种方法可以将类的权重计算出来?
现在我将树导出到字符串:
g = export_graphviz(clf, None)
,并将所有用逗号分隔的数字对替换为新值:
def recalc(s):
values = s[0][1:-1].split(',')
return f"[{values[0]}, {int(values[1]) // weight_for_1}]"
regex = re.compile('\[\d+\, \d+\]')
g = re.sub(regex, recalc, g)
之后,我将字符串写入文件:
with open('tree.dot', 'w') as fw:
fw.write(g)