我正在从Maxim Lapan的Deep_Reinforcement_Learning_Hands-On书中学习RL。这是他们的REINFORCE算法代码。 https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On/blob/master/Chapter09/02_cartpole_reinforce.py 但是想在不使用库ptan的情况下实现此代码,因为我不知道它在做什么。
我尝试执行下面的代码
env = gym.make("CartPole-v0")
writer = SummaryWriter(comment="-cartpole-reinforce")
net = PGN(env.observation_space.shape[0], env.action_space.n)
print(net)
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
total_rewards = []
step_idx = 0
done_episodes = 0
batch_episodes = 0
batch_states, batch_actions, batch_qvals = [], [], []
cur_rewards = []
state=env.reset()
for step_idx in range(100):
batch_states.append(state)
state=torch.Tensor(state)
action=int(net(state).detach().max(0)[1])
batch_actions.append(action)
state,r,done,_=env.step(action)
cur_rewards.append(r)
if done:
batch_qvals.extend(calc_qvals(cur_rewards))
reward=sum(cur_rewards)
cur_rewards.clear()
batch_episodes += 1
state=env.reset()
# handle new rewards
if done:
done_episodes += 1
total_rewards.append(reward)
mean_rewards = float(np.mean(total_rewards[-100:]))
print("%d: reward: %6.2f, mean_100: %6.2f, episodes: %d" % (
step_idx, reward, mean_rewards, done_episodes))
writer.add_scalar("reward", reward, step_idx)
writer.add_scalar("reward_100", mean_rewards, step_idx)
writer.add_scalar("episodes", done_episodes, step_idx)
if mean_rewards > 195:
print("Solved in %d steps and %d episodes!" % (step_idx, done_episodes))
break
if batch_episodes < EPISODES_TO_TRAIN:
continue
optimizer.zero_grad()
states_v = torch.FloatTensor(batch_states)
batch_actions_t = torch.LongTensor(batch_actions)
batch_qvals_v = torch.FloatTensor(batch_qvals)
logits_v = net(states_v)
log_prob_v = F.log_softmax(logits_v, dim=1)
log_prob_actions_v = batch_qvals_v * log_prob_v[range(len(batch_states)), batch_actions_t]
loss_v = -log_prob_actions_v.mean()
loss_v.backward()
optimizer.step()
batch_episodes = 0
batch_states.clear()
batch_actions.clear()
batch_qvals.clear()
writer.close()
``
But this doesn't converge at all, the reward is never more than 10