StackOverflowError when generating a decision tree in Java

Time: 2016-12-30 16:51:18

Tags: java stack-overflow decision-tree id3

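I'm trying to write an ID3 algorithm that generates a decision tree, but I get a StackOverflowError when I run the code. While debugging, I noticed that the looping starts once the number of attributes drops to 4 (from an initial 9). The tree-generation code is below. All the functions I call work correctly and have been tested. The stack trace, however, points to another function that uses streams, although that function has been tested separately and I know it works. Keep in mind that I'm working with random data, so the function sometimes throws the error and sometimes doesn't. I've posted the stack trace below the code; the entropy and information gain functions work.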

This is the TreeNode structure:

public class TreeNode {
    List<Patient> samples;
    List<TreeNode> children;
    TreeNode parent;
    Integer attribute;
    String attributeValue;
    String className;

    public TreeNode(List<Patient> samples, List<TreeNode> children, TreeNode parent, Integer attribute,
            String attributeValue, String className) {
        this.samples = samples;
        this.children = children;
        this.parent = parent;
        this.attribute = attribute;
        this.attributeValue = attributeValue;
        this.className = className;
    }
}

This is the code that throws the error:

public TreeNode id3(List<Patient> patients, List<Integer> attributes, TreeNode root) {
        boolean isLeaf = patients.stream().collect(Collectors.groupingBy(i -> i.className)).keySet().size() == 1;
        if (isLeaf) {
            root.setClassName(patients.get(0).className);
            return root;
        }
        if (attributes.size() == 0) {
            root.setClassName(mostCommonClass(patients));
            return root;
        }
        int bestAttribute = maxInformationGainAttribute(patients, attributes);
        Set<String> attributeValues = attributeValues(patients, bestAttribute);
        for (String value : attributeValues) {
            List<Patient> branch = patients.stream().filter(i -> i.patientData[bestAttribute].equals(value))
                    .collect(Collectors.toList());

            TreeNode child = new TreeNode(branch, new ArrayList<>(), root, bestAttribute, value, null);

            if (branch.isEmpty()) {
                child.setClassName(mostCommonClass(patients));
                root.addChild(new TreeNode(child));
            } else {
                List<Integer> newAttributes = new ArrayList<>();
                newAttributes.addAll(attributes);
                newAttributes.remove(new Integer(bestAttribute));
                root.addChild(new TreeNode(id3(branch, newAttributes, child)));
            }
        }
        return root;
    }

These are the other functions:

public static double entropy(List<Patient> patients) {
        double entropy = 0.0;
        double recurP = (double) patients.stream().filter(i -> i.className.equals("recurrence-events")).count()
                / (double) patients.size();
        double noRecurP = (double) patients.stream().filter(i -> i.className.equals("no-recurrence-events")).count()
                / (double) patients.size();
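        // The conditional operator binds more loosely than '/', so the division by ln 2
        // must sit inside the "true" branch to get a base-2 logarithm; 0 * log(0) counts as 0.
        entropy -= (recurP * (recurP > 0 ? Math.log(recurP) / Math.log(2) : 0)
                + noRecurP * (noRecurP > 0 ? Math.log(noRecurP) / Math.log(2) : 0));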
        return entropy;
    }



public static double informationGain(List<Patient> patients, int attribute) {
        double informationGain = entropy(patients);
        Map<String, List<Patient>> patientsGroupedByAttribute = patients.stream()
                .collect(Collectors.groupingBy(i -> i.patientData[attribute]));
        List<List<Patient>> subsets = new ArrayList<>();
        for (String i : patientsGroupedByAttribute.keySet()) {
            subsets.add(patientsGroupedByAttribute.get(i));
        }

        for (List<Patient> lp : subsets) {
            informationGain -= proportion(lp, patients) * entropy(lp);
        }
        return informationGain;
    }


private static int maxInformationGainAttribute(List<Patient> patients, List<Integer> attributes) {
        int maxAttribute = 0;
        double maxInformationGain = 0;
        for (int i : attributes) {
            if (informationGain(patients, i) > maxInformationGain) {
                maxAttribute = i;
                maxInformationGain = informationGain(patients, i);
            }
        }
        return maxAttribute;
    }
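
One thing worth checking in `maxInformationGainAttribute`: `maxAttribute` and `maxInformationGain` both start at 0 and the comparison is strict, so when no candidate has a gain above 0 the method returns attribute 0, even if 0 is no longer in `attributes`. In `id3`, removing that attribute from the list then changes nothing, the single branch contains the same patients as its parent, and the same call repeats. That would be consistent with an intermittent StackOverflowError on random data, but it is a reading of the code shown here, not a confirmed diagnosis. Below is a minimal sketch of a selection that always returns a member of the candidate list. `AttributeSelection`, `originalPattern`, `bestAttribute` and the `gain` parameter are hypothetical names; `gain` stands in for the original `informationGain(patients, i)`.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.IntToDoubleFunction;

class AttributeSelection {

    // Same shape as the posted method: starts from attribute 0 and gain 0,
    // and only switches on a strictly larger gain.
    static int originalPattern(List<Integer> attributes, IntToDoubleFunction gain) {
        int maxAttribute = 0;
        double maxGain = 0;
        for (int i : attributes) {
            double g = gain.applyAsDouble(i);
            if (g > maxGain) {
                maxAttribute = i;
                maxGain = g;
            }
        }
        return maxAttribute;
    }

    // Sketch of an alternative: seed the search with the first candidate,
    // so the result is always taken from the list that was passed in.
    static int bestAttribute(List<Integer> attributes, IntToDoubleFunction gain) {
        if (attributes.isEmpty()) {
            throw new IllegalArgumentException("no attributes left to choose from");
        }
        int best = attributes.get(0);
        double bestGain = gain.applyAsDouble(best);
        for (int k = 1; k < attributes.size(); k++) {
            int candidate = attributes.get(k);
            double g = gain.applyAsDouble(candidate); // evaluated once per candidate
            if (g > bestGain) {
                best = candidate;
                bestGain = g;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<Integer> remaining = Arrays.asList(3, 5, 7); // attribute 0 was used higher up
        IntToDoubleFunction zeroGain = i -> 0.0;          // no remaining split is informative
        System.out.println(originalPattern(remaining, zeroGain)); // prints 0 (not in the list)
        System.out.println(bestAttribute(remaining, zeroGain));   // prints 3 (in the list)
    }
}
```

With this shape, the attribute removed in `id3` is always present in the list, so the list shrinks by one at every level and the `attributes.size() == 0` check is eventually reached.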

The exception:

Exception in thread "main" java.lang.StackOverflowError
    at java.util.stream.ReferencePipeline$2$1.accept(Unknown Source)
    at java.util.ArrayList$ArrayListSpliterator.forEachRemaining(Unknown Source)
    at java.util.stream.AbstractPipeline.copyInto(Unknown Source)
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(Unknown Source)
    at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(Unknown Source)
    at java.util.stream.AbstractPipeline.evaluate(Unknown Source)
    at java.util.stream.LongPipeline.reduce(Unknown Source)
    at java.util.stream.LongPipeline.sum(Unknown Source)
    at java.util.stream.ReferencePipeline.count(Unknown Source)
    at Patient.entropy(Patient.java:39)
    at Patient.informationGain(Patient.java:67)
    at Patient.maxInformationGainAttribute(Patient.java:85)
    at Patient.id3(Patient.java:109)

1 answer:

Answer 0: (score: 0)

The line:

root.addChild(new TreeNode(id3(branch, newAttributes, child)));

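is called every time the method recurses, and that is what overflows the stack. This tells me there is an error somewhere in your logic, so that no "base case" that ends the recursion (i.e. returns root) is ever reached. I don't know enough about the desired behavior or the starting data to pinpoint what is going wrong, but I would start by stepping through the code with a debugger and checking that the logic in the method behaves the way you expect. I know this isn't a great answer, but it's a starting point, and hopefully it helps or someone else comes along with a more specific solution.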
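
To make the base-case point concrete: the recursion in `id3` can only end if every call receives a strictly smaller problem, and here the attribute list is what has to shrink. `List.remove(Object)` returns `false` when nothing was removed, so that situation can be turned into an immediate, readable failure instead of a StackOverflowError. The following is a small sketch under that assumption; `RecursionGuard`, `without` and `levels` are hypothetical helpers, not part of the posted code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class RecursionGuard {

    // Returns a copy of the attribute list without `used`.
    // Fails fast if `used` was not in the list, because the next
    // recursive call would then see exactly the same attributes.
    static List<Integer> without(List<Integer> attributes, int used) {
        List<Integer> next = new ArrayList<>(attributes);
        if (!next.remove(Integer.valueOf(used))) { // remove(Object), not remove(int index)
            throw new IllegalStateException(
                    "attribute " + used + " is not in " + attributes + "; recursion would not make progress");
        }
        return next;
    }

    // Toy recursion with the same shape as id3: one attribute is consumed per level.
    static int levels(List<Integer> attributes) {
        if (attributes.isEmpty()) {
            return 0; // base case: nothing left to split on
        }
        int chosen = attributes.get(0);
        return 1 + levels(without(attributes, chosen));
    }

    public static void main(String[] args) {
        System.out.println(levels(Arrays.asList(2, 4, 6))); // prints 3
        try {
            without(Arrays.asList(3, 5), 0); // 0 was already used higher up in the tree
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Dropping a check like this into `id3` in place of the bare `newAttributes.remove(...)` call would show at which node, and for which attribute, the recursion stops making progress.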