numpy:将矩阵提升为幂会产生奇怪的结果

时间:2020-04-06 16:31:07

标签: python-3.x numpy

我对KnightDialer problem on Leetcode的解决方案:

import numpy as np
class Solution:
  def knightDialer(self, N: int) -> int:
    M = np.matrix([[0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
                   [0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
                   [0, 0, 0, 0, 1, 0, 0, 0, 1, 0],
                   [1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
                   [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
                   [0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
                   [0, 0, 1, 0, 1, 0, 0, 0, 0, 0]])
    return np.sum(M**(N-1)) % (10**9 + 7)

它适用于N值最大为51。对于N = 1,它正确返回10,对于N = 2,它正确返回20,对于N = 3,它正确返回46,依此类推。在N> 51时,它将停止产生准确的结果(对于N = 52,它将返回107679695,而不是 690023703)。我不知道为什么,但是以某种方式将矩阵提高到> 51的幂会导致结果不准确。

我尝试将M**(N-1)替换为np.linalg.matrix_power(M, (N-1)),但输出仍然不准确。我的直觉是幕后有一些麻木的“魔术”,但我不确定是什么。

1 个答案:

答案 0 :(得分:1)

Numpy很费劲,因为它与固定大小的整数(例如int32int64(在此使用哪个取决于您的python安装)一起使用。这样对矩阵求幂很快会使条目大于该数据类型的限制,从而导致截断。

对于常规Python整数,可以使用这种实现方式(先将矩阵乘以然后再应用模数约简),因为常规Python整数可以具有更高的值。例如:

def mm(A, B):
    return [[sum([x*y for (x, y) in zip(row, col)]) for col in zip(*B)] for row in A]

def knightDialer(N: int) -> int:
    M = [[0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 1, 0, 0, 0, 1, 0],
        [1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 0, 0, 0]]
    N = N - 1
    res = [[int(i == j) for j in range(len(M))] for i in range(len(M))]
    while N:
        if N & 1: res = mm(res, M)
        M = mm(M, M)
        N >>= 1
    print(M)
    return sum([sum(i) for i in zip(*res)]) % (10**9 + 7)

在幂运算过程中应用模块化约简,可使numpy矩阵正常工作而不会用完位:

def knightDialer(N: int) -> int:
    M = np.matrix([[0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
                   [0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
                   [0, 0, 0, 0, 1, 0, 0, 0, 1, 0],
                   [1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
                   [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
                   [0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
                   [0, 0, 1, 0, 1, 0, 0, 0, 0, 0]], dtype=np.int64)
    N = N - 1
    res = np.eye(M.shape[0], dtype=np.int64)
    while N:
        if N & 1: res = res * M % (10**9 + 7)
        M = M * M % (10**9 + 7)
        N >>= 1
    return np.sum(res) % (10**9 + 7)

在安装时必须使用dtype=np.int64,默认整数类型为int32int32足够大,可以容纳10**9 + 7以下的数字,但是在矩阵乘法过程中,两个这样的数字之间将存在乘积(然后也将它们相加),并且如果{{ 1}}。