如何从scikit-learn中导出决策树,并根据树中的类权重来计算值?

时间:2019-04-30 13:52:29

标签: python-3.x scikit-learn graphviz

我正在使用DecisionTreeClassifier研究数据样本,并且样本不平衡(很多0,几个1)。因此,我使用class_weight={0: 1, 1: 5}参数训练模型。

训练后,我需要导出树以说明研究。我使用sklearn.tree.export_graphviz,但它不计算样本中的实际对象数。它显示“值= [zeroes_count * 1,ones_count * 5]”,因此数字使我和我正在显示结果的任何人感到困惑。

是否有一种方法可以将类的权重计算出来?

现在我将树导出到字符串:

g = export_graphviz(clf, None)

,并将所有用逗号分隔的数字对替换为新值:

def recalc(s):
    values = s[0][1:-1].split(',')
    return f"[{values[0]}, {int(values[1]) // weight_for_1}]"

regex = re.compile('\[\d+\, \d+\]')
g = re.sub(regex, recalc, g)

之后,我将字符串写入文件:

with open('tree.dot', 'w') as fw:
    fw.write(g)

0 个答案:

没有答案