My current solution is:
import numpy as np

def distance(state):
    target_state = (1, 2, 3, 4, 5, 6, 7, 8, 0)
    target_matrix = np.reshape(np.asarray(list(target_state)), (-1, 3))
    reshaped_matrix = np.reshape(np.asarray(list(state)), (-1, 3))
    dist = 0
    for i in range(1, 9):
        dist = dist + (abs(np.where(target_matrix == i)[0][0]
                           - np.where(reshaped_matrix == i)[0][0])
                       + abs(np.where(target_matrix == i)[1][0]
                             - np.where(reshaped_matrix == i)[1][0]))
    return dist
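For example (the sample state below is my own; it is one move away from the goal, so the summed Manhattan distance over the eight tiles is 1). The function is reproduced so the example runs on its own:

```python
import numpy as np

def distance(state):
    # summed Manhattan distance of tiles 1..8 to their goal positions
    target_state = (1, 2, 3, 4, 5, 6, 7, 8, 0)
    target_matrix = np.reshape(np.asarray(list(target_state)), (-1, 3))
    reshaped_matrix = np.reshape(np.asarray(list(state)), (-1, 3))
    dist = 0
    for i in range(1, 9):
        dist = dist + (abs(np.where(target_matrix == i)[0][0]
                           - np.where(reshaped_matrix == i)[0][0])
                       + abs(np.where(target_matrix == i)[1][0]
                             - np.where(reshaped_matrix == i)[1][0]))
    return dist

print(distance((1, 2, 3, 4, 5, 6, 7, 8, 0)))  # goal state -> 0
print(distance((1, 2, 3, 4, 5, 6, 7, 0, 8)))  # tile 8 is one column off -> 1
```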
Answer (score: 2)
How about:
import numpy as np

def summed_manhattan(state):
    shuffle = np.reshape((np.array(state) - 1) % 9, (3, 3))
    all_dists = np.abs(np.unravel_index(shuffle, (3, 3)) - np.indices((3, 3)))
    all_dists.shape = 2, 9
    gap = np.where(shuffle.ravel() == 8)[0][0]
    return all_dists[:, :gap].sum() + all_dists[:, gap + 1:].sum()
This improves on your solution by avoiding the repeated calls to where (which make it up to O(n^2)). Instead, exploiting the simple structure of target_state, it computes, for each index into state, the index into target_state that holds the same value; this permutation is stored in shuffle. This small trick makes the algorithm O(n) and, as a bonus, easy to vectorize.
This solution is asymptotically optimal, since one obviously cannot do better than O(n).
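To make the permutation trick concrete (the scrambled state below is my own example): because target_state places value v at flat index (v - 1) % 9, shuffle[i, j] is the goal position of whatever sits at board position (i, j), and one vectorized unravel_index call yields every tile's goal coordinates at once:

```python
import numpy as np

state = (1, 2, 3, 4, 5, 6, 7, 0, 8)  # one move from the goal
# value v lives at flat index (v - 1) % 9 in target_state:
# 1 -> 0, 2 -> 1, ..., 8 -> 7, and the blank 0 -> 8
shuffle = np.reshape((np.array(state) - 1) % 9, (3, 3))
print(shuffle)
# goal coordinates of every board cell, computed in one shot
target_rows, target_cols = np.unravel_index(shuffle, (3, 3))
current_rows, current_cols = np.indices((3, 3))
per_cell = np.abs(target_rows - current_rows) + np.abs(target_cols - current_cols)
print(per_cell)  # per-cell Manhattan distances, blank still included
```

Summing per_cell while skipping the cell whose shuffle entry is 8 (the blank) gives exactly what summed_manhattan returns.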