Can a DQNAgent with PrioritizedMemory fit properly?

Time: 2019-03-28 12:10:02

Tags: reinforcement-learning q-learning keras

I am training a DQNAgent with PrioritizedMemory against my environment, and judging by the rewards it looks like it might be overfitting. But can that really happen when the environment only ever presents new states, or is it rather a problem with the memory?

My code is below, in case you want to look through it and tell me what is going wrong:

import numpy as np
from keras.layers import Input, Embedding, Bidirectional, LSTM, Dense
from keras.models import Model
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.core import Processor
# PrioritizedMemory is not in mainline keras-rl; the import path below assumes
# the prioritized-replay fork I am using keeps it next to the other memories.
from rl.memory import PrioritizedMemory

lr = 1e-3
emb_size = 100
look_back = 10
window_length = 1  # assumed to be 1 here (missing from the snippet as originally posted)

# "Expert" (regular DQN) model architecture
# env is my custom environment (defined elsewhere)
inp = Input(shape=(10,))
emb = Embedding(input_dim=env.action_space.n + 1, output_dim=emb_size)(inp)
rnn = Bidirectional(LSTM(5))(emb)
out = Dense(env.action_space.n, activation='softmax')(rnn)
expert_model = Model(inputs=inp, outputs=out)
expert_model.compile(loss='mse',
                     optimizer='adam',
                     metrics=['acc'])

# memory
memory = PrioritizedMemory(limit=5000, window_length=window_length)

# policy
policy = BoltzmannQPolicy()

# agent
dqn = DQNAgent(model=expert_model, nb_actions=env.action_space.n, policy=policy, memory=memory,
               enable_double_dqn=False, enable_dueling_network=False, gamma=.9,
               batch_size=1,  # doesn't work if I change the batch size
               target_model_update=1e-2, processor=RecoProcessor())

dqn.compile(Adam(lr), metrics=['mae'])

train = dqn.fit(env, log_interval=1000, nb_steps=50000, visualize=False, verbose=1,
                nb_max_episode_steps=None)
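Just to rule out the Keras model itself, a quick standalone forward pass with a batch larger than 1 seems fine (the dummy input below is my own illustration, not real environment data), so I suspect the batch-size failure comes from the agent/processor side:

# Dummy forward pass: the model itself accepts batch_size > 1 without complaint.
dummy = np.random.randint(0, env.action_space.n, size=(4, 10))
print(expert_model.predict(dummy).shape)  # (4, env.action_space.n)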

My RecoProcessor class is:

class RecoProcessor(Processor):
    def process_observation(self, observation):
        look_back = 10
        if observation is None:
            X = np.zeros(look_back)
        else:
            if len(observation) > look_back:
                # keep only the last look_back entries of the observation
                observation = observation[-look_back:]
            observation = np.array(observation)
            if len(observation.shape) == 2:
                X = observation[:, 1]
            else:
                X = np.array([observation[1]])
        if len(X) < look_back:
            # pad with zeros up to look_back
            X = np.append(X, np.zeros(look_back - len(X)))
        return X

    def process_state_batch(self, batch):
        # CHECK: I think something is wrong here
        # print(batch)
        # print(batch[0])
        return batch[0]

    def process_reward(self, reward):
        return reward

    def process_demo_data(self, demo_data):
        for step in demo_data:
            step[0] = self.process_observation(step[0])
            step[2] = self.process_reward(step[2])
        return demo_data
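For reference, this is the direction I am considering for process_state_batch instead of returning batch[0] (just a sketch: it assumes window_length stays at 1, so the state batch the agent hands in has shape (batch_size, 1, look_back)):

def process_state_batch_sketch(batch):
    # Hypothetical replacement: collapse the window axis instead of keeping
    # only the first sample, so the model receives (batch_size, look_back)
    # rather than a single (1, look_back) state.
    batch = np.asarray(batch)
    if batch.ndim == 3 and batch.shape[1] == 1:
        return np.squeeze(batch, axis=1)
    return batch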

0 Answers:

There are no answers.