曼哈顿距离

时间:2017-01-24 22:53:48

标签: python python-3.x numpy

计算manhattan distances

的最佳方式是什么?

我目前的解决方案是:

def distance(state):
    target_state = (1,2,3,4,5,6,7,8,0)
    target_matrix = np.reshape(np.asarray(list(target_state)),(-1,3))
    reshaped_matrix = np.reshape(np.asarray(list(state)),(-1,3))
    dist = 0
    for i in range(1,9):
        dist = dist + (abs(np.where(target_matrix == i)[0][0]
                           - np.where(reshaped_matrix == i)[0][0]) +
                       abs(np.where(target_matrix == i)[1][0]
                           - np.where(reshaped_matrix == i)[1][0]))

    return dist

1 个答案:

答案 0 :(得分:2)

怎么样

import numpy as np

def summed_manhattan(state):
    shuffle = np.reshape((np.array(state)-1) % 9, (3, 3))
    all_dists = np.abs(np.unravel_index(shuffle, (3, 3)) - np.indices((3, 3)))
    all_dists.shape = 2, 9
    gap = np.where(shuffle.ravel() == 8)[0][0]
    return all_dists[:, :gap].sum() + all_dists[:, gap + 1 :].sum()

这可以避免重复调用where(最多为O(n ^ 2))来改善您的解决方案。相反,利用target_state的简单结构,它为每个索引计算状态,将索引转换为持有相同值的target_state;排列存储在随机播放中。这个小技巧使算法O(n)和奖励变得容易矢量化。

这个解决方案是最优的,因为显然不能比O(n)做得更好。