如何优化此递归回溯算法?

时间:2019-06-10 15:56:15

标签: java recursion permutation

我对递归回溯不太熟悉,但我想尝试一下。我编写了这段代码,查找所有整数的排列,这些排列范围从低到高(作为参数传递),总和等于给定int NumTotal。它适用于较小的数字,但我得到

  

异常在线程 “主要” java.lang.OutOfMemoryError:GC开销超出限制

表示较大的数字。

public static void findAllSolutionMethod(int NTotal, ArrayList<ArrayList<Integer>> solutions,
        ArrayList<Integer> currentSolution, int lower, int upper) {

    // success base case
    if (NTotal == 0) {
        ArrayList<Integer> copy = new ArrayList<Integer>();
        // creates deep copy 
        for (int i = 0; i < currentSolution.size(); i++) {
            copy.add(currentSolution.get(i));
        }
        // add to solutions arraylist
        solutions.add(copy);
        return;
    }

    // invalid number base case (number added too big)
    else if (NTotal < 0) {
        return;
    }

    else {
        // iterates through range of numbers
        for (int i = lower; i <= upper; i++) {
            currentSolution.add(i);
            findAllSolutionMethod(NTotal - i, solutions, currentSolution, lower, upper);
            currentSolution.remove(currentSolution.size() - 1);
        }
    }
}

有什么办法可以优化此代码,以使其不占用太多空间?

1 个答案:

答案 0 :(得分:0)

如果您将图像递归调用为树,这将非常容易。在纸上写下解决方案,制作 bfs ,您会看到优化方法。

public static List<int[]> findPermutations(int sum, int low, int high) {
    final Function<Node, int[]> getPath = node -> {
        int[] arr = new int[node.depth()];
        int i = arr.length - 1;

        while (node != null) {
            arr[i--] = node.val;
            node = node.parent;
        }

        return arr;
    };

    List<int[]> res = new LinkedList<>();
    Deque<Node> queue = new LinkedList<>();

    for (int i = low; i <= high; i++) {
        queue.clear();
        queue.add(new Node(low));

        while (!queue.isEmpty()) {
            Node node = queue.remove();

            if (node.sum == sum)
                res.add(getPath.apply(node));
            else {
                for (int j = low; j <= high; j++) {
                    if (node.sum + j <= sum)
                        queue.add(new Node(j, node.sum + j, node));
                    else
                        break;
                }
            }
        }
    }

    return res;
}

private static final class Node {

    private final int val;
    private final int sum;
    private final Node parent;

    public Node(int val) {
        this(val, val, null);
    }

    public Node(int val, int sum, Node parent) {
        this.val = val;
        this.sum = sum;
        this.parent = parent;
    }

    public int depth() {
        return parent == null ? 1 : (parent.depth() + 1);
    }

    @Override
    public String toString() {
        return val + " (" + sum + ')';
    }

}

演示:

findPermutations(4, 1, 3).forEach(path -> System.out.println(Arrays.toString(path)));

输出:

[1, 3]
[1, 1, 2]
[1, 2, 1]
[1, 1, 1, 1]
[2, 2]
[2, 1, 1]
[3, 1]