DQN with prioritized experience replay and target network does not improve

Time: 2018-10-21 20:47:03

Tags: python tensorflow keras reinforcement-learning openai-gym

I'm trying to write a neural network to solve OpenAI's CartPole environment using TensorFlow and Keras. The network uses prioritized experience replay and a separate target network, which is updated every ten episodes. Here is the code:

import numpy as np
import gym
import matplotlib.pyplot as plt
from collections import deque

from tensorflow import keras

class agent:
    def __init__(self,inputs,outputs,gamma):
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_decay = 0.999
        self.epsilon_min = 0.01
        self.inputs = inputs
        self.outputs = outputs
        self.network = self._build_net()
        self.target_network = self._build_net()
        self.replay_memory = deque(maxlen=100000)

        self.target_network.set_weights(self.network.get_weights())

    def _build_net(self):
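        # small fully connected network: three ReLU hidden layers,
        # sigmoid outputs trained with categorical cross-entropy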
        model = keras.models.Sequential()

        model.add(keras.layers.Dense(10,activation='relu',input_dim=self.inputs))
        model.add(keras.layers.Dense(10,activation='relu'))
        model.add(keras.layers.Dense(10,activation='relu'))
        model.add(keras.layers.Dense(self.outputs,activation='sigmoid'))
        model.compile(optimizer='adam',loss='categorical_crossentropy')

        return model

    def act(self,state,testing=False):
        if not testing:
            if self.epsilon > np.random.rand():
                return np.random.randint(self.outputs)
            else:
                return np.argmax(self.network.predict(np.array([state])))

        else:
            return np.argmax(self.network.predict(np.array([state])))

    def remember(self,state,action,next_state,reward,win):
        self.replay_memory.append((state,action,next_state,reward,win))

    def get_batch(self):
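        # score every stored transition against the target network and
        # return the 32 with the smallest loss as (state, target) pairs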
        batch = []
        losses = []
        targets = []

        if len(self.replay_memory) >= 32:
            for state,action,next_state,reward,done in self.replay_memory:
                target_f = np.zeros([1,self.outputs])
                if done != False:
                    target = (reward + self.gamma * np.amax(self.target_network.predict(np.array([next_state]))[0]))
                    target_f[0][action] = target

                targets.append(target_f)
                loss = np.mean(self.network.predict(np.array([state]))-target_f)**2
                losses.append(loss)

            indexes = np.argsort(losses)[:32]

            for indx in indexes:
                batch.append((self.replay_memory[indx][0], targets[indx]))

            return batch

    def replay(self,batch):
        for state,target in batch:
            self.network.fit(np.array([state]), target, epochs=1, verbose=0)
        self.epsilon = max(self.epsilon_min, self.epsilon*self.epsilon_decay)

    def update_target_network(self):
        self.target_network.set_weights(self.network.get_weights())

    def save(self):
        self.network.save_weights('./DQN.model')

    def load(self):
        self.network.load_weights('./DQN.model')

env = gym.make('CartPole-v0')
episodes = 1000
agent = agent(env.observation_space.shape[0],env.action_space.n,0.99)
rewards = []

state = env.reset()
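# pre-fill the replay memory with 500 environment steps before training starts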
for _ in range(500):
    action = agent.act(state,1.0)
    next_state, reward, win, _ = env.step(action)
    agent.remember(state,action,next_state,reward,win)
    state = next_state

    if win:
        state = env.reset()

for e in range(episodes):
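    # one batch is built per episode and replayed after every environment step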
    print('episode:',e+1)
    batch = agent.get_batch()
    state = env.reset()
    for t in range(400):
        action = agent.act(state)
        next_state, reward, win, _ = env.step(action)
        agent.remember(state,action,next_state,reward,win)
        state = next_state

        if win:
            break

        agent.replay(batch)

    print('score:',t)
    print('epsilon:',agent.epsilon)
    print('')

    if e%10 == 0:
        agent.update_target_network()
        agent.save()

    rewards.append(t)
    plt.plot(list(range(e+1)),rewards)
    plt.savefig('./reward.png')

The problem is that the agent gets worse as epsilon decreases, and by the time epsilon reaches its minimum the agent survives only 7-9 steps before the pole falls, as shown in the image below. Can someone tell me why my agent isn't learning anything and how to fix it?

[Image: plot of episode rewards over training]

1 Answer:

Answer 0 (score: 0):

Try setting reward = -reward when the episode is done. It helps.
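
A minimal sketch of what that change could look like in the question's training loop (here win is the environment's done flag, as in the question's code; everything else is unchanged):

for t in range(400):
    action = agent.act(state)
    next_state, reward, win, _ = env.step(action)
    if win:
        # penalize the terminal transition so the agent
        # learns that letting the pole fall is bad
        reward = -reward
    agent.remember(state,action,next_state,reward,win)
    state = next_state

    if win:
        break

    agent.replay(batch)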