Memory issue when updating weights in reinforcement learning

Asked: 2019-12-11 08:57:28

Tags: python python-3.x out-of-memory reinforcement-learning

I am using the PPO (Proximal Policy Optimization) algorithm for RL. My RAM usage shoots up every time an update happens. Basically, I run 2500 episodes with a max_timestep of 400 each, and in every episode the agent backpropagates and updates the weights after 400 steps, i.e.

               if time_step % update_timestep == 0:
                    self.update(memory)
                    memory.clear_memory()
                    time_step = 0
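
To see how much resident memory each update actually adds, I can log the process RSS around the call (a quick diagnostic sketch; psutil is an extra dependency and log_rss is just a helper name made up for this example):

    import os
    import psutil

    def log_rss(tag):
        # resident set size of the current process, in MB
        rss_mb = psutil.Process(os.getpid()).memory_info().rss / (1024 ** 2)
        print("[{}] RSS = {:.1f} MB".format(tag, rss_mb))

    if time_step % update_timestep == 0:
        log_rss("before update")
        self.update(memory)
        memory.clear_memory()
        log_rss("after clear_memory")
        time_step = 0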

The update function that updates the weights is:

    def update(self, memory):
        # Monte Carlo estimate of rewards:
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalizing the rewards:
        rewards = torch.tensor(rewards).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values :
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Finding the ratio (pi_theta / pi_theta__old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss:
            advantages = rewards - state_values.detach()   
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

self.K_epochs follows the PPO pseudocode, which runs that particular loop for K epochs (in my case the loop runs 80 times, since self.K_epochs = 80). Note: as you can see in the first snippet, right after calling the update function I also call another function that clears the memory, i.e.

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]
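
To narrow down where the Python-level growth comes from, the standard library's tracemalloc can compare allocation snapshots taken a few updates apart (a rough sketch; note it only tracks allocations made through Python, so tensor storage allocated by PyTorch's C++ backend may not show up):

    import tracemalloc

    tracemalloc.start()
    snapshot_before = tracemalloc.take_snapshot()

    # ... run a few episodes / updates here ...

    snapshot_after = tracemalloc.take_snapshot()
    # print the ten source lines whose allocations grew the most
    for stat in snapshot_after.compare_to(snapshot_before, 'lineno')[:10]:
        print(stat)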

So basically there should not be any memory problem, because inside the self.K_epochs for loop of the update function I also zero all the gradients before updating the weights. Yet my training run only lasts about 1000 episodes before the RAM fills up and the code is killed.
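
For reference, one general PyTorch pattern that can make memory grow steadily is keeping loss tensors in a Python list while they are still attached to the autograd graph, instead of storing plain floats. A minimal standalone sketch of the difference (the names here are made up purely for illustration):

    import torch

    w = torch.randn(3, requires_grad=True)
    kept_tensors, kept_floats = [], []

    for _ in range(1000):
        loss = (w * torch.randn(3)).sum()
        kept_tensors.append(loss)        # retains the tensor and its whole autograd graph
        kept_floats.append(loss.item())  # stores only a Python float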

I am attaching the complete code below for reference:

#!/usr/bin/env python

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np
from ur5_env_kdl import Robot2Env
import matplotlib.pyplot as plt
from matplotlib import gridspec
import os
import sys
import pprint
import copy
from copy import deepcopy
import PyKDL
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()
        # action mean range -1 to 1
        self.policy_loss = []
        #self.entropy_loss = []
        self.actor =  nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.LeakyReLU(),
                nn.Linear(64, 32),
                nn.LeakyReLU(),
                nn.Linear(32, action_dim),
                nn.Tanh()
                )
        #print("LOSS", self.actor[0].weights)  
        # critic
        self.critic = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.LeakyReLU(),
                nn.Linear(64, 32),
                nn.LeakyReLU(),
                nn.Linear(32, 1)
                )
        self.action_var = torch.full((action_dim,), action_std*action_std).to(device)
        #print("VAR", self.action_var)
    def forward(self):
        raise NotImplementedError

    def act(self, state, memory):
        #print("self.actor", self.actor)
        action_mean = self.actor(state)
        #print("action_mean", action_mean)
        cov_mat = torch.diag(self.action_var).to(device)
        #print("cov_mat", cov_mat)
        dist = MultivariateNormal(action_mean, cov_mat)
        #print("DIST", dist)
        action = dist.sample()
        #print("action", action)
        action_logprob = dist.log_prob(action)
        #print("action_logprob", action_logprob)
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()

    def evaluate(self, state, action):   
        action_mean = torch.squeeze(self.actor(state))
        #print("action_mean_squeeze", action_mean)
        action_var = self.action_var.expand_as(action_mean)
        #print("EV action_var", action_var)
        cov_mat = torch.diag_embed(action_var).to(device)
        #print("EV cov_mat", cov_mat)
        dist = MultivariateNormal(action_mean, cov_mat)
        action_logprobs = dist.log_prob(torch.squeeze(action))
        policy_loss = -action_logprobs.mean()
        # stored for plotting later; appended as a live tensor (with its autograd graph) into a list that is never cleared
        self.policy_loss.append(policy_loss)
        #print("Evaluate action_logprobs", action_logprobs)
        dist_entropy = dist.entropy()
        #self.entropy_loss.append(dist_entropy)
        #print("ENTROPY", dist_entropy)
        state_value = self.critic(state)
        #print("state_value", state_value)
        return action_logprobs, torch.squeeze(state_value), dist_entropy

class PPO:
    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        #print("STATE", state)
        return self.policy_old.act(state, memory).cpu().data.numpy().flatten()

    def update(self, memory):
        # Monte Carlo estimate of rewards:
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalizing the rewards:
        rewards = torch.tensor(rewards).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values :
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Finding the ratio (pi_theta / pi_theta__old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss:
            advantages = rewards - state_values.detach()   
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

    def training_loop(self, env, max_episodes, max_timesteps, update_timestep, render):
        avg_length = 0
        time_step = 0
        self.avg_rewards = []
        self.episode_counter = 0
        self.running_reward = []
        running_rew = 0
        memory = Memory()
        # training loop
        for i_episode in range(1, max_episodes+1):
            state = env.take_observation()
            curr_rewards = []
            self.episode_counter += 1.0
            for t in range(max_timesteps):
                time_step += 1.0
                sys.stdout.write("\r Time_Step {}".format(time_step))
                sys.stdout.flush()
                # Running policy_old:
                action = self.select_action(state, memory)

                #print("AG-action", action)
                state, reward, done, info = env.step(action)

                # Saving reward and is_terminals:
                memory.rewards.append(reward)
                memory.is_terminals.append(done)
                running_rew += reward
                curr_rewards.append(reward)
                # update if its time
                if time_step % update_timestep == 0:
                    self.update(memory)
                    memory.clear_memory()
                    time_step = 0
                if render:
                    env.render()
                if done:
                    break
            self.running_reward.append(running_rew)
            running_rew = 0
            avg_length += t
            print("Episode", self.episode_counter)
            pprint.pprint(info)

            # stop training if avg_reward > solved_reward
            # if running_reward > (log_interval*solved_reward):
            #     print("########## Solved! ##########")
            #     torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env))
            #     break

            # save every 2000 episodes
            if i_episode % 2000 == 0:
                torch.save(self.policy.state_dict(), './PPO_continuous_{}.pth'.format(env))

            # logging
            # if i_episode % log_interval == 0:
            #     avg_length = int(avg_length/log_interval)
            #     running_reward = int((running_reward/log_interval))

            #     print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))
            #     running_reward = 0
            #     avg_length = 0
            self.avg_rewards.append(np.mean(curr_rewards))
        self.plot_results()

    def plot_results(self):

        policy_loss = self.policy.policy_loss
        #entropy_loss = self.policy.entropy_loss    
        plt.figure(figsize=(15,10))
        gs = gridspec.GridSpec(3, 2)
        ax0 = plt.subplot(gs[0,:])
        ax0.plot(self.episode_counter)
        ax0.plot(self.avg_rewards)
        ax0.set_ylim([-10, 20])
        ax0.set_xlabel('Episodes')
        plt.title('Mean Rewards')

        ax1 = plt.subplot(gs[1,:])
        #ax1 = plt.subplot(gs[1, 0])
        ax1.plot(policy_loss)
        plt.title('Policy Loss')
        plt.xlabel('Update Number')

        ax2 = plt.subplot(gs[2, :])
        ax2.plot(self.episode_counter)
        ax2.plot(self.running_reward)
        ax2.set_ylim([-500, 100])
        plt.title('Sum of rewards per episode')
        plt.xlabel('Update Number')

        plt.tight_layout()
        plt.show()



def main():
    ############## Hyperparameters ##############
    env = Robot2Env()
    render = False
    #solved_reward = 300         # stop training if avg_reward > solved_reward
    #log_interval = 20           # print avg reward in the interval
    max_episodes = 2500        # max training episodes
    max_timesteps = 400        # max timesteps in one episode

    update_timestep = 400      # update policy every n timesteps
    action_std = 0.5            # constant std for action distribution (Multivariate Normal)
    K_epochs = 80               # update policy for K epochs
    eps_clip = 0.2              # clip parameter for PPO
    gamma = 0.99                # discount factor
    lr = 0.0003                 # parameters for Adam optimizer
    betas = (0.9, 0.999)

    random_seed = None
    #############################################

    # creating environment

    state_dim = env.observation_space.shape[0]
    #print("AG-STATE_DIM", state_dim)
    action_dim = env.action_space.shape[0]
    #print("AG-ACTION_DIM", action_dim)
    if random_seed:
        #print("Random Seed: {}".format(random_seed))
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)


    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
    #print(lr,betas)
    ppo.training_loop(env, max_episodes, max_timesteps, update_timestep, render)

    # logging variables


if __name__ == '__main__':
    main()

0 Answers:

No answers yet.