完美平方leetcode-带有记忆的递归解决方案

时间:2020-11-04 09:51:05

标签: java algorithm recursion dynamic-programming perfect-square

尝试解决递归和记事的this问题,但对于输入7168,我得到了错误的答案。

    public int numSquares(int n) {
        Map<Integer, Integer> memo = new HashMap();
        List<Integer> list = fillSquares(n, memo);
        if (list == null)
            return 1;
        return helper(list.size()-1, list, n, memo);
    }
    
    private int helper(int index, List<Integer> list, int left, Map<Integer, Integer> memo) {
        
        if (left == 0)
            return 0;
        if (left < 0 || index < 0)
            return Integer.MAX_VALUE-1;
        
        if (memo.containsKey(left)) {
            return memo.get(left);
        }
        
        int d1 = 1+helper(index, list, left-list.get(index), memo);
        int d2 = 1+helper(index-1, list, left-list.get(index),  memo);
        int d3 = helper(index-1, list, left, memo);
        
        int d = Math.min(Math.min(d1,d2), d3);
        memo.put(left, d);
        return d;
    }
    
    private List<Integer> fillSquares(int n, Map<Integer, Integer> memo) {
        int curr = 1;
        List<Integer> list = new ArrayList();
        int d = (int)Math.pow(curr, 2);
        while (d < n) {
            list.add(d);
            memo.put(d, 1);
            curr++;
            d = (int)Math.pow(curr, 2);
        }
        if (d == n)
            return null;
        return list;
    }

我这样打电话:

numSquares(7168)

所有测试用例都通过了(即使是复杂的用例),但此失败。我怀疑自己的备忘有问题,但无法准确指出。任何帮助将不胜感激。

2 个答案:

答案 0 :(得分:1)

您已根据要获得的值来确定备忘,但这并没有考虑index的值,这实际上限制了可以用来获得该值的能力。这意味着如果(在极端情况下)index为0,则只能减少剩下的平方(1²),这很少是形成该数字的最佳方法。因此,memo.set()首先将注册一个非最佳平方数,稍后将通过递归树中待处理的其他递归调用进行更新。

如果添加一些条件调试代码,则会看到多次调用map.set的相同值left,并且具有不同的值。这不好,因为这意味着if (memo.has(left))块将在尚未保证该值是最佳值的情况下执行。

您可以通过在记事密钥中加入index来解决此问题。这会增加用于记忆的空间,但是它将起作用。我想你可以解决这个问题。

但是根据Lagrange's four square theorem,每个自然数都可以写为最多四个平方的和,因此返回值永远不能为5或更大。当您通过该数量的术语时,可以简化递归。这降低了使用备忘录的好处。

最后,fillSquares中有一个错误:当它是一个完美的正方形时,它也应该添加n本身,否则您将找不到应该返回1的解。

答案 1 :(得分:0)

  • 不确定您的错误,这是一个简短的动态编程解决方案:

Java

public class Solution {
    public static final int numSquares(
        final int n
    ) {
        int[] dp = new int[n + 1];
        Arrays.fill(dp, Integer.MAX_VALUE);
        dp[0] = 0;

        for (int i = 1; i <= n; i++) {
            int j = 1;
            int min = Integer.MAX_VALUE;

            while (i - j * j >= 0) {
                min = Math.min(min, dp[i - j * j] + 1);
                ++j;
            }

            dp[i] = min;
        }

        return dp[n];
    }
}

C ++

// Most of headers are already included;
// Can be removed;
#include <iostream>
#include <cstdint>
#include <vector>
#include <algorithm>

// The following block might slightly improve the execution time;
// Can be removed;
static const auto __optimize__ = []() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();


#define MAX INT_MAX

using ValueType = std::uint_fast32_t;

struct Solution {
    static const int numSquares(
        const int n
    ) {
        if (n < 1) {
            return 0;
        }

        static std::vector<ValueType> count_perfect_squares{0};

        while (std::size(count_perfect_squares) <= n) {
            const ValueType len = std::size(count_perfect_squares);
            ValueType count_squares = MAX;

            for (ValueType index = 1; index * index <= len; ++index) {
                count_squares = std::min(count_squares, 1 + count_perfect_squares[len - index * index]);
            }

            count_perfect_squares.emplace_back(count_squares);
        }

        return count_perfect_squares[n];
    }
};

int main() {
    std::cout <<  std::to_string(Solution().numSquares(12) == 3) << "\n";

    return 0;
}

Python

  • 在这里我们可以简单地使用lru_cache
class Solution:
    dp = [0]
    @functools.lru_cache
    def numSquares(self, n):
        dp = self.dp
        while len(dp) <= n:
            dp += min(dp[-i * i] for i in range(1, int(len(dp) ** 0.5 + 1))) + 1, 
        return dp[n]

以下是LeetCode的官方解决方案,并附有评论:

Java:DP

class Solution {

  public int numSquares(int n) {
    int dp[] = new int[n + 1];
    Arrays.fill(dp, Integer.MAX_VALUE);
    // bottom case
    dp[0] = 0;

    // pre-calculate the square numbers.
    int max_square_index = (int) Math.sqrt(n) + 1;
    int square_nums[] = new int[max_square_index];
    for (int i = 1; i < max_square_index; ++i) {
      square_nums[i] = i * i;
    }

    for (int i = 1; i <= n; ++i) {
      for (int s = 1; s < max_square_index; ++s) {
        if (i < square_nums[s])
          break;
        dp[i] = Math.min(dp[i], dp[i - square_nums[s]] + 1);
      }
    }
    return dp[n];
  }
}

Java:贪婪

class Solution {
  Set<Integer> square_nums = new HashSet<Integer>();

  protected boolean is_divided_by(int n, int count) {
    if (count == 1) {
      return square_nums.contains(n);
    }

    for (Integer square : square_nums) {
      if (is_divided_by(n - square, count - 1)) {
        return true;
      }
    }
    return false;
  }

  public int numSquares(int n) {
    this.square_nums.clear();

    for (int i = 1; i * i <= n; ++i) {
      this.square_nums.add(i * i);
    }

    int count = 1;
    for (; count <= n; ++count) {
      if (is_divided_by(n, count))
        return count;
    }
    return count;
  }
}

Java:广度优先搜索

class Solution {
  public int numSquares(int n) {

    ArrayList<Integer> square_nums = new ArrayList<Integer>();
    for (int i = 1; i * i <= n; ++i) {
      square_nums.add(i * i);
    }

    Set<Integer> queue = new HashSet<Integer>();
    queue.add(n);

    int level = 0;
    while (queue.size() > 0) {
      level += 1;
      Set<Integer> next_queue = new HashSet<Integer>();

      for (Integer remainder : queue) {
        for (Integer square : square_nums) {
          if (remainder.equals(square)) {
            return level;
          } else if (remainder < square) {
            break;
          } else {
            next_queue.add(remainder - square);
          }
        }
      }
      queue = next_queue;
    }
    return level;
  }
}

Java:使用数学的最​​有效解决方案

  • 运行时:O(N ^ 0.5)
  • 内存:O(1)
class Solution {

  protected boolean isSquare(int n) {
    int sq = (int) Math.sqrt(n);
    return n == sq * sq;
  }

  public int numSquares(int n) {
    // four-square and three-square theorems.
    while (n % 4 == 0)
      n /= 4;
    if (n % 8 == 7)
      return 4;

    if (this.isSquare(n))
      return 1;
    // enumeration to check if the number can be decomposed into sum of two squares.
    for (int i = 1; i * i <= n; ++i) {
      if (this.isSquare(n - i * i))
        return 2;
    }
    // bottom case of three-square theorem.
    return 3;
  }
}