有人可以建议,这下面的代码有什么问题可以找到BST中k个最小元素的总和?它返回树中所有节点的总和。
public int findSum(Node root, int k){
int count = 0;
return findSumRec(root, k, count);
}
private int findSumRec(Node root, int k, int count) {
if(root == null)
return 0;
if(count > k)
return 0;
int sum = findSumRec(root.left, k, count);
if(count >= k)
return sum;
sum += root.data;
count++;
if(count >= k)
return sum;
return sum + findSumRec(root.right, k, count);
}
答案 0 :(得分:0)
嗯,Java是一种按值传递的语言,因此在一个方法调用中更改count
的值并不会更改调用它的方法的count
的值。
假设在递归过程中的某个时刻,当k
为5
且count
为4
时,您正在调用:
int sum = findSumRec(root.left, 5, 4);
假设此次调用会将传递给count
的{{1}}增加到5
,并返回一些sum
。
现在回到调用findSumRec(root.left, 5, 4)
的方法,然后检查:
if(4 >= 5)
return sum;
即使最近返回的递归调用将count
提升到5
,这意味着您应该完成,调用者仍会将count
视为4
(因为它&# 39; s不是相同的count
变量),因此树的遍历将继续,直到您访问树的所有节点并将它们全部加起来。
您必须使用一些可变实例来修改count
。
例如,您可以使用具有单个元素的数组:
public int findSum(Node root, int k){
int[] count = {0};
return findSumRec(root, k, count);
}
和
private int findSumRec(Node root, int k, int[] count) {
...
change each count to count[0]
...
}
编辑:我刚用我建议的修正测试了你的代码,它确实有效。
public int findSum(Node root, int k) {
int[] count = {0};
return findSumRec(root, k, count);
}
private int findSumRec(Node root, int k, int[] count) {
if(root == null)
return 0;
if(count[0] > k)
return 0;
int sum = findSumRec(root.left, k, count);
if(count[0] >= k)
return sum;
sum += root.val;
count[0]++;
if(count[0] >= k)
return sum;
return sum + findSumRec(root.right, k, count);
}
关键是对findSumRec
的所有递归方法调用必须共享count
变量的值并能够更改它。当count
是传递给方法的原始变量时,这是不可能的,因为每个方法都获得该变量的不同副本。
使用数组是一种选择。另一种选择是使用包含方法的类的成员变量,而不是将其作为参数传递。这样它仍然可以是int
。
答案 1 :(得分:-1)
我认为你正在寻找这样的东西,这是我在java中的代码
import java.io.*;
public class Main {
public static class Node {
int val;
Node left;
Node right;
public Node(int data) {
this.val = data;
this.left = null;
this.right = null;
}
}
public static int sum_ans = 0;
public static int count_k_smallest_sum(Node root, int count, int k)
{
if(count>=k)
{
return 100005;
}
if(root.left == null && root.right == null)
{
sum_ans += root.val;
return count+1;
}
int cnt = count;
if(root.left != null)
{
cnt = count_k_smallest_sum(root.left, cnt, k);
}
if(cnt >= k)
{
return 100005;
}
sum_ans += root.val;
cnt++;
if(cnt >= k)
{
return 100005;
}
if(root.right != null)
{
cnt = count_k_smallest_sum(root.right, cnt, k);
}
return cnt;
}
public static void main(String args[]) throws Exception {
Node root = new Node(10);
Node left1 = new Node(5);
Node right1 = new Node(11);
Node left2 = new Node(3);
Node right2 =new Node(12);
Node left3 = new Node(2);
Node right3 = new Node(7);
Node right4 = new Node(4);
root.left = left1;
root.right = right1;
right1.right = right2;
left1.left = left2;
left1.right = right3;
left2.left = left3;
left2.right = right4;
int cnt = count_k_smallest_sum(root, 0, 3);
System.out.println(sum_ans);
}
}
请参阅方法中的代码逻辑 - count_k_smallest_sum。
希望这有帮助!