我对KnightDialer problem on Leetcode的解决方案:
import numpy as np
class Solution:
def knightDialer(self, N: int) -> int:
M = np.matrix([[0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0, 1, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0, 0, 0, 0]])
return np.sum(M**(N-1)) % (10**9 + 7)
它适用于N值最大为51。对于N = 1,它正确返回10,对于N = 2,它正确返回20,对于N = 3,它正确返回46,依此类推。在N> 51时,它将停止产生准确的结果(对于N = 52,它将返回107679695,而不是 690023703)。我不知道为什么,但是以某种方式将矩阵提高到> 51的幂会导致结果不准确。
我尝试将M**(N-1)
替换为np.linalg.matrix_power(M, (N-1))
,但输出仍然不准确。我的直觉是幕后有一些麻木的“魔术”,但我不确定是什么。
答案 0 :(得分:1)
Numpy很费劲,因为它与固定大小的整数(例如int32
或int64
(在此使用哪个取决于您的python安装)一起使用。这样对矩阵求幂很快会使条目大于该数据类型的限制,从而导致截断。
对于常规Python整数,可以使用这种实现方式(先将矩阵乘以然后再应用模数约简),因为常规Python整数可以具有更高的值。例如:
def mm(A, B):
return [[sum([x*y for (x, y) in zip(row, col)]) for col in zip(*B)] for row in A]
def knightDialer(N: int) -> int:
M = [[0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0, 1, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0, 0, 0, 0]]
N = N - 1
res = [[int(i == j) for j in range(len(M))] for i in range(len(M))]
while N:
if N & 1: res = mm(res, M)
M = mm(M, M)
N >>= 1
print(M)
return sum([sum(i) for i in zip(*res)]) % (10**9 + 7)
在幂运算过程中应用模块化约简,可使numpy矩阵正常工作而不会用完位:
def knightDialer(N: int) -> int:
M = np.matrix([[0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0, 1, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0, 0, 0, 0]], dtype=np.int64)
N = N - 1
res = np.eye(M.shape[0], dtype=np.int64)
while N:
if N & 1: res = res * M % (10**9 + 7)
M = M * M % (10**9 + 7)
N >>= 1
return np.sum(res) % (10**9 + 7)
在安装时必须使用dtype=np.int64
,默认整数类型为int32
。 int32
足够大,可以容纳10**9 + 7
以下的数字,但是在矩阵乘法过程中,两个这样的数字之间将存在乘积(然后也将它们相加),并且如果{{ 1}}。