keras-rl Processor class changes the shape

Date: 2019-03-26 09:12:25

Tags: python keras openai-gym keras-rl

Well, I'm trying to feed a keras-rl model a list of 10 integers, but since I'm using a new OpenAI-Gym environment, I think I need to set up a Processor class. My processor class looks like this:

import numpy as np
from rl.core import Processor


class RecoProcessor(Processor):
    def process_observation(self, observation):
        print("Observation:")
#         print(observation.shape)
        look_back = 10
        if observation is None:
            X = np.zeros(look_back)
        else:
            X = np.array(observation, dtype='float32')
        # X = np.append(X, np.zeros(look_back - len(X)))
        print(X.shape)
        return X

    def process_state_batch(self, batch):
        print("Batch:")
        print(batch.shape)
        return batch

    def process_reward(self, reward):
        return reward

    def process_demo_data(self, demo_data):
        for step in demo_data:
            step[0] = self.process_observation(step[0])
            step[2] = self.process_reward(step[2])
        return demo_data

My agent and model are as follows:

window_length = 1
emb_size = 100
look_back = 10

# "Expert" (regular dqn) model architecture
expert_model = Sequential()
# expert_model.add(Input(shape=(look_back,window_length)))
expert_model.add(Embedding(env.action_space.n+1, emb_size, input_length=look_back,mask_zero=True))
expert_model.add(LSTM(64, input_shape=(look_back,window_length)))
expert_model.add(Dense(env.action_space.n, activation='softmax'))

# try using different optimizers and different optimizer configs
expert_model.compile(loss='mse',
                     optimizer='adam',
                     metrics=['acc'])

# memory
memory = PrioritizedMemory(limit=5000,  window_length=window_length)

# policy
policy = BoltzmannQPolicy()

# agent
dqn = DQNAgent(model=expert_model, nb_actions=env.action_space.n, policy=policy, memory=memory, 
               enable_double_dqn=False, enable_dueling_network=False, gamma=.6, 
               target_model_update=1e-2, nb_steps_warmup=100, processor = RecoProcessor())

But when I try to run this, the output I get is the following:

Training for 50000 steps ...
CCCCCCCCCCCC
(10,)
Interval 1 (0 steps performed)
AAAAAAAAAAAAAAA
(1, 1, 10)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-29-4d8fdf0e849e> in <module>
     32 dqn.compile(Adam(lr), metrics=['mae'])
     33 
---> 34 train = dqn.fit(env, nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps = None)
     35 np.savetxt(fichero_train_history, 
     36            np.array(train.history["episode_reward"]), delimiter=",")

c:\users\eloy.anguiano\src\keras-rl\rl\core.py in fit(self, env, nb_steps, action_repetition, callbacks, verbose, visualize, nb_max_start_steps, start_step_policy, log_interval, nb_max_episode_steps)
    167                 # This is were all of the work happens. We first perceive and compute the action
    168                 # (forward step) and then use the reward to improve (backward step).
--> 169                 action = self.forward(observation)
    170                 if self.processor is not None:
    171                     action = self.processor.process_action(action)

c:\users\eloy.anguiano\src\keras-rl\rl\agents\dqn.py in forward(self, observation)
     87         # Select an action.
     88         state = self.memory.get_recent_state(observation)
---> 89         q_values = self.compute_q_values(state)
     90         if self.training:
     91             action = self.policy.select_action(q_values=q_values)

c:\users\eloy.anguiano\src\keras-rl\rl\agents\dqn.py in compute_q_values(self, state)
     67 
     68     def compute_q_values(self, state):
---> 69         q_values = self.compute_batch_q_values([state]).flatten()
     70         assert q_values.shape == (self.nb_actions,)
     71         return q_values

c:\users\eloy.anguiano\src\keras-rl\rl\agents\dqn.py in compute_batch_q_values(self, state_batch)
     62     def compute_batch_q_values(self, state_batch):
     63         batch = self.process_state_batch(state_batch)
---> 64         q_values = self.model.predict_on_batch(batch)
     65         assert q_values.shape == (len(state_batch), self.nb_actions)
     66         return q_values

~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras-2.2.4-py3.7.egg\keras\engine\training.py in predict_on_batch(self, x)
   1266             Numpy array(s) of predictions.
   1267         
-> 1268         x, _, _ = self._standardize_user_data(x)
   1269         if self._uses_dynamic_learning_phase():
   1270             ins = x + [0.]

~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras-2.2.4-py3.7.egg\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    749             feed_input_shapes,
    750             check_batch_axis=False,  # Don't enforce the batch size.
--> 751             exception_prefix='input')
    752 
    753         if y is not None:

~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras-2.2.4-py3.7.egg\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    126                         ': expected ' + names[i] + ' to have ' +
    127                         str(len(shape)) + ' dimensions, but got array '
--> 128                         'with shape ' + str(data_shape))
    129                 if not check_batch_axis:
    130                     data_shape = data_shape[1:]

ValueError: Error when checking input: expected embedding_12_input to have 2 dimensions, but got array with shape (1, 1, 10)

As you can see, the shape that reaches the model is the batched one, and I don't know how to fix it. In case you want to experiment, the environment I'm using is RecoGym (version 1).

1 Answer:

Answer 0 (score: 0):

You can remove the unnecessary dimension with your own processor.

import numpy
from rl.core import Processor


class CustomProcessor(Processor):
    '''
    acts as a coupling mechanism between the agent and the environment
    '''

    def process_state_batch(self, batch):
        '''
        Given a state batch, I want to remove the second dimension, because it's
        useless and prevents me from feeding the tensor into my CNN
        '''
        return numpy.squeeze(batch, axis=1)

Then pass it to your agent, for example:

agent = DQNAgent(... , processor=CustomProcessor())
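
Applied to the model in the question, the same idea reshapes the (1, 1, 10) state batch into the (1, 10) input the Embedding layer expects. Below is a minimal sketch, not tested against RecoGym; it assumes window_length=1 and observations no longer than look_back:

import numpy as np
from rl.core import Processor


class RecoProcessor(Processor):
    """Pads observations to look_back entries and drops the window axis."""

    def __init__(self, look_back=10):
        self.look_back = look_back

    def process_observation(self, observation):
        # A missing observation becomes a zero vector of length look_back.
        if observation is None:
            return np.zeros(self.look_back, dtype='float32')
        X = np.array(observation, dtype='float32')
        # Zero-pad shorter observations up to look_back entries.
        if X.shape[0] < self.look_back:
            X = np.append(X, np.zeros(self.look_back - X.shape[0], dtype='float32'))
        return X

    def process_state_batch(self, batch):
        # keras-rl delivers (batch_size, window_length, look_back); with
        # window_length=1 the middle axis can be dropped, so the Embedding
        # layer receives the 2-D input (batch_size, look_back) it expects.
        return np.squeeze(batch, axis=1)

With a processor like this passed to DQNAgent, predict_on_batch receives a 2-D array and the dimension check in standardize_input_data should no longer raise the ValueError shown above.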

Answer taken from this thread: https://github.com/keras-rl/keras-rl/issues/113