为什么翻译的数独求解器比原来慢?

时间:2016-08-28 15:29:37

标签: java python numpy sudoku

我将Java Sudoku解算器转录为python。一切正常,但解决需要花费2分钟,而相同的拼图在Java中只需要几秒钟。此外,所需的迭代数量完全相同。我错过了什么吗?

import numpy as np
def solve_recursive(puzzle, pos):
        if(pos == 81):
            print puzzle
            return True
        if(puzzle[pos] != 0):
            if (not solve_recursive(puzzle, pos+1)):
                return False
            else:
                return True

        row = np.copy(puzzle[pos//9*9:pos//9*9+9])
        col = np.copy(puzzle[pos%9::9])
        short = (pos%9)//3*3 + pos//27*27
        square = np.concatenate((puzzle[short:short+3],puzzle[short+9:short+12],puzzle[short+18:short+21]))

        for i in range(1,10):
            puzzle[pos] = i
            if(i not in row and i not in col and i not in square and solve_recursive(puzzle, pos+1)):
                return True

        puzzle[pos] = 0
        return False
puzzle = np.array([[0,0,0,0,0,0,0,8,3],
              [0,2,0,1,0,0,0,0,0],
              [0,0,0,0,0,0,0,4,0],
              [0,0,0,6,1,0,2,0,0],
              [8,0,0,0,0,0,9,0,0],
              [0,0,4,0,0,0,0,0,0],
              [0,6,0,3,0,0,5,0,0],
              [1,0,0,0,0,0,0,7,0],
              [0,0,0,0,0,8,0,0,0]])
solve_recursive(puzzle.ravel(), 0)

编辑:

正如@hpaulj所建议的,我重新设计了我的代码以使用numpy 2D数组:

import numpy as np
def solve_recursive(puzzle, pos):
        if pos == (0,9):
            print puzzle
            raise Exception("Solution")
        if(puzzle[pos] != 0):
            if(pos[0] == 8):
                solve_recursive(puzzle, (0,pos[1]+1))
                return
            elif pos[0] < 8:
                solve_recursive(puzzle, (pos[0]+1, pos[1]))
                return

        for i in range(1,10):
            if(i not in puzzle[pos[0]] and i not in puzzle[:,pos[1]] and i not in puzzle[pos[0]//3*3:pos[0]//3*3+3,pos[1]//3*3:pos[1]//3*3+3]):
                puzzle[pos] = i
                if(pos[0] == 8):
                    solve_recursive(puzzle, (0,pos[1]+1))
                elif pos[0] < 8:
                    solve_recursive(puzzle, (pos[0]+1, pos[1]))
        puzzle[pos] = 0
puzzle = np.array([[0,0,0,0,0,0,0,8,3],
          [0,2,0,1,0,0,0,0,0],
          [0,0,0,0,0,0,0,4,0],
          [0,0,0,6,1,0,2,0,0],
          [8,0,0,0,0,0,9,0,0],
          [0,0,4,0,0,0,0,0,0],
          [0,6,0,3,0,0,5,0,0],
          [1,0,0,0,0,0,0,7,0],
          [0,0,0,0,0,8,0,0,0]])
solve_recursive(puzzle, (0,0))

忽略在递归调用的底部抛出异常的事实是相当不优雅的,这比我原来的解决方案快得多。使用像链接的Norvig求解器这样的词典是一个合理的选择吗?

1 个答案:

答案 0 :(得分:2)

我修改了你的函数来打印pos并保持其被调用次数的运行计数。我会提前停止它。

停在pos==46会导致1190次通话,只有轻微的可见延迟。但是对于47,计数是416621,一分钟或更长时间。

假设它正在进行某种递归搜索,那么这个问题的难度在46到47之间就会出现。

是的,作为解释语言的Python将比Java运行得慢。可能的解决方案包括弄清楚为什么在递归调用中会出现这种跳跃。或者提高每次通话的速度。

您设置了9x9 numpy数组,但随后立即对其进行了调整。然后,函数本身将该板视为81个值的列表。这意味着选择行和列以及子矩阵比阵列仍为2d要复杂得多。实际上,数组只是一个列表。

我可以想象两种加速通话的方法。一种是重新编码以使用列表板。对于小型数组和迭代操作,列表比数组具有更少的开销,因此通常更快。另一种方法是对其进行编码以真正利用数组的二维特性。只有当numpy使用已编译的代码执行大多数操作时,numpy解决方案才有用。对数组的迭代很慢。

==================

更改函数以使其与平面列表而不是raveled数组一起使用,运行速度更快。对于47的最大位置,它在15秒内运行,而原始位置为1m 15s(相同的板和迭代计数)。

我正在清理一个2d numpy阵列版本,但没有让它更快。

纯列表版本也适合在pypy上运行得更快。

使用2d数组的代码的一部分

    r,c = np.unravel_index(pos, (9,9))            
    if(puzzle[r,c] != 0):
        return solve_numpy(puzzle, pos+1)
    row = puzzle[r,:].copy()
    col = puzzle[:,c].copy()
    r1, c1 = 3*(r//3), 3*(c//3)
    square = puzzle[r1:r1+3, c1:c1+3].flatten()
    for i in range(1,10):
        puzzle[r,c] = i
        if(i not in row and i not in col and i not in square):
            if solve_numpy(puzzle, pos+1):
                return True
    puzzle[r,c] = 0

索引更清晰,但没有速度提升。除了更简单的索引之外,它并没有充分利用整个数组操作。

list版本看起来与原版不同,但速度要快得多:

    row = puzzle[pos//9*9:pos//9*9+9]
    col = puzzle[pos%9::9]
    short = (pos%9)//3*3 + pos//27*27
    square = puzzle[short:short+3] + \
             puzzle[short+9:short+12] + \
             puzzle[short+18:short+21]

http://norvig.com/sudoku.html由AI专家讨论使用pythoN的数独求解方法。

使用此Norvig解算器,您的网格解决方案需要0.01秒。信息主要存储在词典中。你的案例很简单,可以通过他的2个基本任务策略来解决。没有搜索解决方案的速度非常快。