二叉搜索树中K个最小元素的总和

时间:2017-05-04 12:14:46

标签: java algorithm recursion binary-search-tree

有人可以建议,这下面的代码有什么问题可以找到BST中k个最小元素的总和?它返回树中所有节点的总和。

public int findSum(Node root, int k){
            int count = 0;
            return findSumRec(root, k, count);
        }

        private int findSumRec(Node root, int k, int count) {

            if(root == null)
                return 0;
            if(count > k)
                return 0;

            int sum = findSumRec(root.left, k, count);
            if(count >= k)
                return sum;

            sum += root.data;
            count++;

            if(count >= k)
                return sum;
            return sum + findSumRec(root.right, k, count);
            }

2 个答案:

答案 0 :(得分:0)

嗯,Java是一种按值传递的语言,因此在一个方法调用中更改count的值并不会更改调用它的方法的count的值。

假设在递归过程中的某个时刻,当k5count4时,您正在调用:

        int sum = findSumRec(root.left, 5, 4);

假设此次调用会将传递给count的{​​{1}}增加到5,并返回一些sum

现在回到调用findSumRec(root.left, 5, 4)的方法,然后检查:

        if(4 >= 5)
            return sum;

即使最近返回的递归调用将count提升到5,这意味着您应该完成,调用者仍会将count视为4(因为它&# 39; s不是相同的count变量),因此树的遍历将继续,直到您访问树的所有节点并将它们全部加起来。

您必须使用一些可变实例来修改count

例如,您可以使用具有单个元素的数组:

public int findSum(Node root, int k){
    int[] count = {0};
    return findSumRec(root, k, count);
}

private int findSumRec(Node root, int k, int[] count) {
    ...
    change each count to count[0]
    ...
}

编辑:我刚用我建议的修正测试了你的代码,它确实有效。

public int findSum(Node root, int k) {
    int[] count = {0};
    return findSumRec(root, k, count);
}

private int findSumRec(Node root, int k, int[] count) {

    if(root == null)
        return 0;
    if(count[0] > k)
        return 0;

    int sum = findSumRec(root.left, k, count);
    if(count[0] >= k)
        return sum;

    sum += root.val;
    count[0]++;

    if(count[0] >= k)
        return sum;
    return sum + findSumRec(root.right, k, count);
}

关键是对findSumRec的所有递归方法调用必须共享count变量的值并能够更改它。当count是传递给方法的原始变量时,这是不可能的,因为每个方法都获得该变量的不同副本。

使用数组是一种选择。另一种选择是使用包含方法的类的成员变量,而不是将其作为参数传递。这样它仍然可以是int

答案 1 :(得分:-1)

我认为你正在寻找这样的东西,这是我在java中的代码

import java.io.*;

public class Main {

  public static class Node {
    int val;
    Node left;
    Node right;
    public Node(int data) {
      this.val = data;
      this.left = null;
      this.right = null;
    }
  }

  public static int sum_ans = 0;

  public static int count_k_smallest_sum(Node root, int count, int k)
  {
    if(count>=k)
    {
      return 100005;
    }
    if(root.left == null && root.right == null)
    {
      sum_ans += root.val;
      return count+1;
    }
    int cnt = count;
    if(root.left != null)
    {
      cnt = count_k_smallest_sum(root.left, cnt, k);
    }
    if(cnt >= k)
    {
      return 100005;
    }
    sum_ans += root.val;
    cnt++;

    if(cnt >= k)
    {
      return 100005;
    }

    if(root.right != null)
    {
      cnt = count_k_smallest_sum(root.right, cnt, k);
    }

    return cnt;
  }

  public static void main(String args[]) throws Exception {
    Node root = new Node(10);
    Node left1 = new Node(5);
    Node right1 = new Node(11);
    Node left2 = new Node(3);
    Node right2 =new Node(12);
    Node left3 = new Node(2);
    Node right3 = new Node(7);
    Node right4 = new Node(4);
    root.left = left1;
    root.right = right1;
    right1.right = right2;
    left1.left = left2;
    left1.right = right3;
    left2.left = left3;
    left2.right = right4;
    int cnt = count_k_smallest_sum(root, 0, 3);
    System.out.println(sum_ans);
  }
}

请参阅方法中的代码逻辑 - count_k_smallest_sum。

希望这有帮助!