I built a DQN to learn tic-tac-toe, and I noticed that the weights do not change at all during training. Something seems to be wrong with my implementation of the target/prediction or of the loss/optimizer. I am new to PyTorch, so I can't tell what I messed up. Here is the relevant code:

import pickle
import numpy as np
import torch

def get_pred_and_target(st, next_state, act, player, discount):
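    # Q-value predicted by the online net for the action that was taken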
    pred = torch.tensor([net(torch.tensor(st).float()).squeeze().detach().numpy()[act]])
    # Define reward
    reward = 0.
    winner, game_status = check_result(next_state)
    if game_status == 'Done' and winner == player:
        reward = 1.
    if game_status == 'Done' and winner != player:
        reward = -1.
    if game_status == 'Draw':
        reward = 1.
    # Define target
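    # (What I want here is the standard DQN target: reward + discount * max_a Q_target(next_state, a).)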
    if next_state.count(0) == 0:
        target = torch.tensor([reward], requires_grad=True).float()
    else:
        target = torch.tensor([reward]).float() + discount * torch.max(
            target_net(torch.tensor(st).float()))
    return pred, target

# Training against intelligent agent
num_epochs = 10000
epsilon_array = np.linspace(0.8, 0.1, num_epochs) # epsilon decays with every epoch
lr_array = np.linspace(1e-2, 1e-9, num_epochs)
results = []
results_val = []
percentages = []
percentages_val = []
preds = torch.tensor([]).float()
targets = torch.tensor([]).float()
training = True
validation = False
playing = False
batch_size = 5
update_target = 20
for param in net.parameters():
    param.requires_grad = True
if training:
    for epoch in range(num_epochs):
        # Define Optimizer
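        # (a new Adam instance every epoch, so the learning rate follows lr_array)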
        optimizer = torch.optim.Adam(net.parameters(), lr=lr_array[epoch], weight_decay=1e-8)
        # Produce batch
        for i in range(batch_size):
            # Clear Board
            state = [0, 0, 0, 0, 0, 0, 0, 0, 0]
            epsilon = epsilon_array[epoch]
            game_status = 'Not Done'
            winner = None
            players_turn = np.random.choice([0, 1])
            while game_status == 'Not Done':
                if players_turn == 0:  # X's move
                    # print("\nAI X's turn!")
                    action = select_action(state, epsilon)
                    new_state = play_move(state, 1, action)
                else:  # O's move
                    # print("\nAI O's turn!")
                    action = select_random_action(state)
                    new_state = play_move(state, -1, action)
                # get pred and target for Q(s,a)
                pred, target = get_pred_and_target(state, new_state, action, 1, discount=0.99)
                # update batch
                preds = torch.cat([preds, pred])
                targets = torch.cat([targets, target])
                # update state
                state = new_state.copy()
                # print_board(new_state)
                winner, game_status = check_result(state)
                if winner is not None:
                    # print(str(winner) + ' won!')
                    if winner == 1:
                        results.append('X')
                    else:
                        results.append('O')
                else:
                    players_turn = (players_turn + 1) % 2
                if game_status == 'Draw':
                    # print('Draw!')
                    results.append('Draw')
        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        # Update
        optimizer.step()
        # Clear batch
        preds = torch.tensor([]).float()
        targets = torch.tensor([]).float()
        # update target net
        if epoch % update_target == 0:
            print('Epoch: ' + str(epoch))
            print(torch.mean(loss))
            print(pred)
            target_net = pickle.loads(pickle.dumps(net))
            percentage = results[-700:].count('O') / 7
            percentages.append(percentage)
            print(f'Random player win percentage in the last 700 games: {percentage} %')
print('Training Complete')
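
Not shown above are `net`, `target_net`, `loss_fn`, and the board helpers (`check_result`, `select_action`, `select_random_action`, `play_move`). A minimal stand-in along the following lines should be enough to run the snippet; it is only a sketch, and my actual network architecture and helper functions may differ in the details:

import pickle
import numpy as np
import torch

# Small MLP mapping the 9-cell board (1 = X, -1 = O, 0 = empty) to 9 action values.
net = torch.nn.Sequential(
    torch.nn.Linear(9, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 9),
)
target_net = pickle.loads(pickle.dumps(net))  # frozen copy used for the bootstrapped targets
loss_fn = torch.nn.MSELoss()

def select_action(state, epsilon):
    # Epsilon-greedy over the empty cells only.
    empty = [i for i, c in enumerate(state) if c == 0]
    if np.random.rand() < epsilon:
        return int(np.random.choice(empty))
    with torch.no_grad():
        q_values = net(torch.tensor(state).float()).squeeze()
    return max(empty, key=lambda i: q_values[i].item())

def select_random_action(state):
    # Uniformly random move among the empty cells.
    empty = [i for i, c in enumerate(state) if c == 0]
    return int(np.random.choice(empty))

def play_move(state, player, action):
    # Return a new board with the player's mark placed at `action`.
    new_state = state.copy()
    new_state[action] = player
    return new_state

def check_result(state):
    # Returns (winner, game_status) with winner in {1, -1, None} and
    # game_status in {'Done', 'Draw', 'Not Done'}.
    lines = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]
    for a, b, c in lines:
        if state[a] != 0 and state[a] == state[b] == state[c]:
            return state[a], 'Done'
    if state.count(0) == 0:
        return None, 'Draw'
    return None, 'Not Done'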