Python方法更改字典

时间:2016-12-22 18:00:10

标签: python numpy dictionary

我正在构建一个学习玩tic tac toe的小蟒蛇脚本。我的过程是存储在游戏中进行的每个动作,并根据它是否导致获胜结果进行评分。最终我会尝试在很多轮比赛中训练这个。

我的问题在于我的update_weights()方法。我希望它采取存储的移动(从board对象访问并表示为列表[row,col])并遍历该移动列表。然后,该方法应引用板的存储权重((3,3)numpy数组的字典)并更新相应的权重以进行适当的移动。

e.g。假设获胜。在获胜序列中,移动#2位于棋盘位置[0,1]。该方法应该访问权重字典(键是移动#)并将数组[0,1]的位置乘以因子1.05。

问题是我的方法正在改变我的权重字典中的所有数组,而不仅仅是与正确的移动#键相关联的数组。我无法弄清楚这是怎么回事。

import numpy as np
import random

    class ttt_board():

        def __init__(self):
            self.board_state = np.array([[0,0,0],[0,0,0],[0,0,0]])
            self.board_weight = self.reset_board_weights()
            self.moves = []

        def reset_board_weights(self):
            board_weight_instance = np.zeros((3,3))
            board_weight_instance[board_weight_instance >= 0] = 0.5

            board_weight = {0: board_weight_instance,
                            1: board_weight_instance,
                            2: board_weight_instance,
                            3: board_weight_instance,
                            4: board_weight_instance}

            return board_weight

        def reset_board(self):
            self.board_state = np.array([[0,0,0],[0,0,0],[0,0,0]])

        def reset_moves(self):
            self.moves = []

        def is_win(self):
            board = self.board_state 
            if board.trace() == 3 or np.flipud(board).trace() == 3:
                return True
            for i in range(3):
                if board.sum(axis=0)[i] == 3 or board.sum(axis=1)[i] == 3:
                    return True
            else:
                return False

        def is_loss(self):
            board = self.board_state 
            if board.trace() == 12 or np.flipud(board).trace() == 12:
                return True
            for i in range(3):
                if board.sum(axis=0)[i] == 12 or board.sum(axis=1)[i] == 12:
                    return True
            else:
                return False

        def is_tie(self):
            board = self.board_state
            board_full = True
            for i in range(len(board)):
                for k in range(len(board)):
                    if board[i][k] == 0:
                        board_full = False
            if board_full and  not self.is_win() and not self.is_loss():
                return True
            else:
                return False

        def update_board(self,player,space):
            #takes player as 1 or 4
            #takes space as list [0,0]
            self.board_state[space[0],space[1]] = player

            if player == 1:
                self.store_move(space)
            return

        def get_avail_spots(self):
            avail_spots = []
            board = self.board_state
            for i in range(len(board)):
                for k in range(len(board)):
                    if board[i][k] == 0:
                        avail_spots.append([i,k])
            return avail_spots

        def gen_next_move(self):
            avail_spots = self.get_avail_spots()
            move = random.randrange(len(avail_spots))
            return avail_spots[move]

        def update_weights(self,win):
            moves = self.moves
            if win:
                factor = 1.05
            else:
                factor= 0.95
            for i in range(len(moves)):
                row = moves[i][0]
                col = moves[i][1]
                old_weight = self.board_weight[i][row,col]
                new_weight = old_weight*factor
                self.board_weight[i][row,col] = new_weight
            return

        def store_move(self,move):
            self.moves.append(move)
            return


    if __name__ == '__main__':

        board = ttt_board()

        while not board.is_win() and not board.is_loss() and not board.is_tie():
            try:
                board.update_board(1,board.gen_next_move())
                board.update_board(4,board.gen_next_move())
            except ValueError:
                break

        if board.is_win():
            board.update_weights(1)
            print('Player 1 wins: {w}'.format(w=board.is_win()))
        elif board.is_loss():
            board.update_weights(0)
            print('Player 2 wins: {l}'.format(l=board.is_loss()))
        elif board.is_tie():
            print('Game ends in tie: {t}'.format(t=board.is_tie()))

        print('Here is the final board')
        print(board.board_state)
        print(board.board_weight)
        print(board.moves)

正如您通过运行脚本所看到的,单个游戏后打印的权重字典对于每个键具有相同的数组值。 我希望每个数组只能在一个位置进行更改,因为它只能用于与其关联的键对应的移动#。

1 个答案:

答案 0 :(得分:2)

问题是您在字典

中的board_weight_instance数组上共享相同的引用
board_weight_instance = np.zeros((3,3))
board_weight_instance[board_weight_instance >= 0] = 0.5

board_weight = {0: board_weight_instance,
                1: board_weight_instance,
                2: board_weight_instance,
                3: board_weight_instance,
                4: board_weight_instance}

我会在字典理解中这样做,使用辅助方法为每个元素创建一个新的引用:

@staticmethod
def create_element():
   board_weight_instance = np.zeros((3,3))
   board_weight_instance[:] = 0.5  # simpler than your method
   return board_weight_instance

board_weight = {i:self.create_element() for i in range(0,5)}

在你的情况下,为什么甚至在你可以使用list时使用字典:没有散列,更快的处理:

board_weight = [self.create_element() for _ in range(0,5)]

你可以用同样的方式访问它