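I'm trying to build a model that uses reinforcement learning with an Actor-Critic policy to predict buy or sell signals for a stock. I'm new to machine learning and PyTorch. While digging into this problem, I've come to suspect that the model isn't really learning anything, even though it looks as if it is when I only watch how the wandb charts evolve.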
Also, I'm saving and loading the model with these two functions:
```python
import os
import torch

def save_model(self, path: str, name: str):
    # Save the actor and critic weights to two separate files
    torch.save(self.actor.state_dict(), os.path.join(path, f"{name}_actor"))
    torch.save(self.critic.state_dict(), os.path.join(path, f"{name}_critic"))

def load_model(self, path: str, name: str):
    # Load the weights back into the existing actor and critic modules
    self.actor.load_state_dict(torch.load(os.path.join(path, f"{name}_actor")))
    self.critic.load_state_dict(torch.load(os.path.join(path, f"{name}_critic")))
```
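The two functions above restore only the network weights. If the goal is to resume training from a checkpoint, a common pattern is to store the optimizer states in the same file, because Adam-style optimizers keep per-parameter running statistics. The sketch below assumes two optimizers like `optimizer_actor` and `optimizer_critic` from the question; `save_checkpoint`, `load_checkpoint`, and the dict keys are hypothetical names. `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, actor, critic, opt_actor, opt_critic):
    # Hypothetical helper: bundle weights and optimizer state into one file
    torch.save(
        {
            "actor": actor.state_dict(),
            "critic": critic.state_dict(),
            "optimizer_actor": opt_actor.state_dict(),
            "optimizer_critic": opt_critic.state_dict(),
        },
        path,
    )


def load_checkpoint(path, actor, critic, opt_actor, opt_critic, device="cpu"):
    # map_location remaps tensors saved on another device (e.g. CUDA) onto `device`
    checkpoint = torch.load(path, map_location=device)
    actor.load_state_dict(checkpoint["actor"])
    critic.load_state_dict(checkpoint["critic"])
    opt_actor.load_state_dict(checkpoint["optimizer_actor"])
    opt_critic.load_state_dict(checkpoint["optimizer_critic"])


# Toy networks standing in for the real actor and critic
actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
opt_a = torch.optim.Adam(actor.parameters())
opt_c = torch.optim.Adam(critic.parameters())

# One optimizer step so that Adam has internal state worth saving
loss = actor(torch.randn(3, 4)).sum() + critic(torch.randn(3, 4)).sum()
loss.backward()
opt_a.step()
opt_c.step()

ckpt_path = os.path.join(tempfile.mkdtemp(), "agent.pt")
save_checkpoint(ckpt_path, actor, critic, opt_a, opt_c)

# Fresh modules and optimizers, then restore everything from the file
restored_actor, restored_critic = nn.Linear(4, 2), nn.Linear(4, 1)
restored_opt_a = torch.optim.Adam(restored_actor.parameters())
restored_opt_c = torch.optim.Adam(restored_critic.parameters())
load_checkpoint(ckpt_path, restored_actor, restored_critic, restored_opt_a, restored_opt_c)
```

Keeping everything in one file also avoids the actor and critic checkpoints drifting out of sync.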
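What I find really strange, though, is that in my select_action function I always pick action 1 (sell) out of the 2 actions (buy or sell). The only time the action is different from 1 is when random_for_egreedy is greater than epsilon: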
```python
import torch

def select_action(self, state, epsilon):
    random_for_egreedy = torch.rand(1)[0]
    if random_for_egreedy > epsilon:
        # Greedy branch: take the action with the highest actor output
        with torch.no_grad():
            state = torch.Tensor(state.values).to(device)
            actor_action = self.actor(state)
            action = torch.argmax(actor_action)
            action = action.item()
    else:
        # Exploration branch: sample a random action from the environment
        action = self.gym.action_space.sample()
    return action
```
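In actor-critic methods the actor already defines a stochastic policy, so a common approach is to sample the action from that distribution instead of combining `argmax` with epsilon-greedy. Exploration then comes from the policy itself, and a policy that has collapsed onto one action becomes visible as probabilities close to 0 and 1. The sketch below assumes the actor ends in a softmax and returns probabilities over the 2 actions; `sample_action` and the toy actor are hypothetical.

```python
import torch
import torch.nn as nn


def sample_action(actor, state):
    """Sample an action from the actor's categorical policy (hypothetical helper)."""
    with torch.no_grad():
        probs = actor(torch.as_tensor(state, dtype=torch.float32))
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()
    # Return the probabilities too, so they can be logged alongside the action
    return action.item(), probs


# Toy actor: 4 input features -> probabilities over 2 actions (buy, sell)
torch.manual_seed(0)
actor = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=-1))
state = [0.1, -0.2, 0.3, 0.0]
action, probs = sample_action(actor, state)
print(action, probs)
```

Logging `probs` next to the chosen action would show directly whether the policy has collapsed onto action 1, which is the symptom described above.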
This is my optimize function:
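```python
import numpy as np
import torch
import wandb

def optimize(self):
    # Wait until the replay memory holds at least one full batch
    if len(self.memory) < self.config.batch_size:
        return
    state, action, new_state, reward, done = self.memory.sample(batch_size=self.config.batch_size)
    state = torch.Tensor(np.array(state)).to(device)
    new_state = torch.Tensor(np.array(new_state)).to(device)
    reward = torch.Tensor(reward).to(device)
    action = torch.LongTensor(action).to(device)
    done = torch.Tensor(done).to(device)
    # Categorical's first positional argument is `probs`, so this assumes the actor ends in a softmax
    dist = torch.distributions.Categorical(self.actor(state))
    # One-step TD error used as the advantage.
    # NOTE: if self.critic(...) returns shape (batch, 1) while reward and done are (batch,),
    # broadcasting turns `advantage` into a (batch, batch) matrix.
    advantage = reward + (1 - done) * self.config.gamma * self.critic(new_state) - self.critic(state)
    critic_loss = advantage.pow(2).mean()
    self.optimizer_critic.zero_grad()
    critic_loss.backward()
    self.optimizer_critic.step()
    # Detach the advantage so the actor loss does not backpropagate into the critic
    actor_loss = -dist.log_prob(action) * advantage.detach()
    self.optimizer_actor.zero_grad()
    actor_loss.mean().backward()
    self.optimizer_actor.step()
    # .item() logs plain Python floats instead of tensors
    wandb.log({"Actor Loss": actor_loss.mean().item(), "Critic Loss": critic_loss.item()})
```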
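One thing worth checking in `optimize` is tensor shapes. If the critic ends in a layer such as `nn.Linear(..., 1)`, `self.critic(state)` has shape `(batch, 1)` while `reward` and `done` have shape `(batch,)`, and PyTorch broadcasting turns `advantage` into a `(batch, batch)` matrix that pairs every reward with every other sample's value estimate. The sketch below (`td_advantage` and the toy critic are hypothetical) reproduces the problem, then squeezes the critic output and computes the bootstrap target under `torch.no_grad()`, which is the usual semi-gradient form of the one-step TD update.

```python
import torch
import torch.nn as nn


def td_advantage(critic, state, new_state, reward, done, gamma):
    """One-step TD error with explicit (batch,) shapes (hypothetical helper)."""
    value = critic(state).squeeze(-1)  # (batch, 1) -> (batch,)
    with torch.no_grad():
        # The bootstrap target is treated as a constant
        next_value = critic(new_state).squeeze(-1)
    target = reward + (1 - done) * gamma * next_value
    return target - value  # (batch,)


batch, n_features = 8, 4
critic = nn.Linear(n_features, 1)
state = torch.randn(batch, n_features)
new_state = torch.randn(batch, n_features)
reward = torch.randn(batch)
done = torch.zeros(batch)

# Without squeezing, the expression from the question broadcasts to (batch, batch)
broadcast = reward + (1 - done) * 0.99 * critic(new_state) - critic(state)
print(broadcast.shape)

advantage = td_advantage(critic, state, new_state, reward, done, gamma=0.99)
print(advantage.shape)
```

With a `(batch,)` advantage, `-dist.log_prob(action) * advantage.detach()` is also `(batch,)`, so each log-probability is weighted by its own sample's advantage.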
And this is my training loop:
```python
for ep in range(conf.num_episode):
    state = env.reset()
    step = 0
    # qnet_agent.reset_running_loss()
    wandb.log({"Episode": ep})
    if ep % save_after_episode == 0:
        qnet_agent.save_model("checkpoints", model_save_name)
    while True:
        wandb.log({"step": step})
        step += 1
        frames_total += 1
        epsilon = calculate_epsilon(frames_total)
        action = qnet_agent.select_action(state, epsilon)
        wandb.log({"last action": action})
        new_state, reward, done, info = env.step(action)
        wandb.log({"Current profit": info['current_profit']})
        wandb.log({"Total profit": info['total_profit']})
        wandb.log({"reward": reward})
        memory.push(state, action, new_state, reward, done)
        qnet_agent.optimize()
        state = new_state
        if done:
            steps_total.append(step)
            break
```
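The loop calls `calculate_epsilon(frames_total)`, whose definition isn't shown. A common schedule is exponential decay from a start value toward a floor. The version below is a hypothetical sketch: `EPS_START`, `EPS_END`, and `EPS_DECAY` are assumed values, not the ones used in the question. With the `select_action` logic above, epsilon is the probability of taking a random action, so a high epsilon means mostly random actions and a low epsilon means mostly actor actions.

```python
import math

# Hypothetical constants: tune them to the number of frames in an actual run
EPS_START = 1.0
EPS_END = 0.05
EPS_DECAY = 10_000  # the gap to EPS_END shrinks by a factor of e every EPS_DECAY frames


def calculate_epsilon(frames_total):
    """Exponentially decay the probability of taking a random action."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-frames_total / EPS_DECAY)


for frames in (0, 10_000, 50_000):
    print(frames, round(calculate_epsilon(frames), 3))
```

Logging epsilon to wandb next to "last action" would make it easy to see which branch of `select_action` is producing the non-1 actions.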
Can anyone tell me whether I'm missing something, or doing something wrong?