在PyTorch中实施的SARSA代理有什么问题?

时间:2019-11-19 04:31:03

标签: pytorch reinforcement-learning sarsa

我使用pyTorch和OpenAI-GYM环境(​​CartPole)。

但是它的分数并没有总体上提高。

我不知道为什么它运行不正确。

我的代码有什么问题?

import gym
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

它是导入设置。

class Sarsa(nn.Module):
  def __init__(self):
    super(Sarsa, self).__init__()

    self.fc1 = nn.Linear(4, 64)
    self.fc2 = nn.Linear(64,2)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    q = self.fc2(x)
    return q

  def sample_action(self, obs, epsilon):
    out = self.forward(obs)
    coin = random.random()
    if coin < epsilon:
      return random.randint(0, 1)
    else:
      return out.argmax().item()

它是代理代码。 具有动作值网络。(Q函数近似)


def train(q, gamma, optimizer, sarsa):
  s, a, r, s_prime, a_prime = sarsa
  q_out = q(torch.tensor(s, dtype=torch.float))
  q_a = q_out[a]
  q_prime_out = q(torch.tensor(s_prime, dtype=torch.float))
  q_prime_a = q_prime_out[a_prime]
  td_target = r + gamma * q_prime_a
  loss = nn.MSELoss()
  loss = loss(td_target.detach(), q_a)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

这是火车代码。 它与一对SARSA一起训练一个步骤。 我在td_target和v之间使用MSE损失函数。

def main():
  env = gym.make('CartPole-v1')
  q = Sarsa()

  gamma = 0.98
  optimizer = optim.Adam(q.parameters(), lr=0.0005)

  score = 0.0
  print_interval = 20

  for episode in range(10000):
    epsilon = max(0.01, 0.1 - 0.01*(episode/200))
    s = env.reset()
    sar = None

    for step in range(600):
      a = q.sample_action(torch.tensor(s, dtype=torch.float), epsilon)
      s_prime, r, done, info = env.step(a)

      if done:
        break

      if sar is not None:
        s_last, a_last, r_last = sar
        sarsa = (s_last, a_last, r_last, s, a)
        train(q, gamma, optimizer, sarsa)
      sar = (s, a, r)
      s = s_prime
      score += r

    if episode % print_interval == 0 and episode != 0:
      print("episode {}'s avg score : {}".format(episode, score/print_interval))
      score = 0.0

  env.close()

if __name__ == '__main__':
  main()

这是我代码中的主要功能。

0 个答案:

没有答案