TensorFlow RL agent ValueError: wrong input shape, but where?

Asked: 2021-07-21 19:14:33

Tags: python tensorflow keras reinforcement-learning openai-gym

I'm trying to train a reinforcement-learning agent on the game 2048, using an Env I designed myself. The error says the input must have shape (1, 16), but an array of shape (1, 4) is being passed, and I can't figure out where in my code an array of that shape gets passed:

class CardGameEnv(Env):
    def __init__(self):
        self.action_space = Discrete(3)
        
        self.observation_space = Box(low=np.array([0 for _ in range(16)]), high=np.array([np.inf for _ in range(16)]))
        
        self._state = [0 for _ in range(16)]
        
        self._episode_ended = False
        
        self._score = 0

    def action_spec(self):
        return self.action_space

    def observation_spec(self):
        return self.observation_space

    def reset(self):
        self._state = [0 for _ in range(16)]
        
        self._episode_ended = False
        
        restart()
        print(np.array([self._state]))
        return ts.restart(np.array([self._state], dtype=np.int32))

    def step(self, action):

        if self._episode_ended:
            # The last action ended the episode. Ignore the current action
            # and start a new episode.
            return self.reset()
        
        old_score = self._score
        print(action)
        make_move(action)
        
        if check_game_over():
            self._episode_ended = True
        
        self._score = get_score()
        
        self._state = list(get_board().flatten())
        
        score = get_score()
        
        if self._episode_ended:
            reward = -10
            return self._state, reward, self._episode_ended, {}
        
        elif score == old_score:
            reward = -10
            return self._state, reward, self._episode_ended, {} 
        
        else:
            reward = self._score - old_score
            return self._state, reward, self._episode_ended, {}

def build_model(states, actions):
    model = Sequential()
    model.add(Flatten(input_shape=states))
    model.add(Dense(24, activation='relu'))
    model.add(Dense(24, activation='relu'))
    model.add(Dense(actions, activation='softmax'))
    return model

def build_agent(model, actions):
    policy = BoltzmannQPolicy()
    memory = SequentialMemory(limit=50000, window_length=1)
    dqn = DQNAgent(model=model, memory=memory, policy=policy, 
                  nb_actions=actions, nb_steps_warmup=10, target_model_update=1e-2)
    return dqn

env = CardGameEnv()
model = build_model((1, 16), 4) 

dqn = build_agent(model, 4)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
dqn.fit(env, nb_steps=50000, visualize=False, verbose=1)

These are the relevant parts of my code. According to the error I'm passing an array of shape (1, 4), but everywhere I use a (1, 16) array for the state. A quick check of the model's expected shapes and the full traceback are below:
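
The network itself should be expecting exactly the shape the error mentions; one quick way to confirm that is to inspect the shape attributes Keras records for the model defined above (a minimal check, nothing keras-rl specific):

# With Flatten(input_shape=(1, 16)), the network should take one window
# of 16 board cells per sample and emit one value per action.
print(model.input_shape)   # expected: (None, 1, 16)
print(model.output_shape)  # expected: (None, 4)
model.summary()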

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-65-edd0afcc8057> in <module>
      8 dqn = build_agent(model, 4)
      9 dqn.compile(Adam(lr=1e-3), metrics=['mae'])
---> 10 dqn.fit(env, nb_steps=50000, visualize=False, verbose=1)

~\anaconda3\lib\site-packages\rl\core.py in fit(self, env, nb_steps, action_repetition, callbacks, verbose, visualize, nb_max_start_steps, start_step_policy, log_interval, nb_max_episode_steps)
    166                 # This is were all of the work happens. We first perceive and compute the action
    167                 # (forward step) and then use the reward to improve (backward step).
--> 168                 action = self.forward(observation)
    169                 if self.processor is not None:
    170                     action = self.processor.process_action(action)

~\anaconda3\lib\site-packages\rl\agents\dqn.py in forward(self, observation)
    222         # Select an action.
    223         state = self.memory.get_recent_state(observation)
--> 224         q_values = self.compute_q_values(state)
    225         if self.training:
    226             action = self.policy.select_action(q_values=q_values)

~\anaconda3\lib\site-packages\rl\agents\dqn.py in compute_q_values(self, state)
     66 
     67     def compute_q_values(self, state):
---> 68         q_values = self.compute_batch_q_values([state]).flatten()
     69         assert q_values.shape == (self.nb_actions,)
     70         return q_values

~\anaconda3\lib\site-packages\rl\agents\dqn.py in compute_batch_q_values(self, state_batch)
     61     def compute_batch_q_values(self, state_batch):
     62         batch = self.process_state_batch(state_batch)
---> 63         q_values = self.model.predict_on_batch(batch)
     64         assert q_values.shape == (len(state_batch), self.nb_actions)
     65         return q_values

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_v1.py in predict_on_batch(self, x)
   1203           ' tf.distribute.Strategy.')
   1204     # Validate and standardize user data.
-> 1205     inputs, _, _ = self._standardize_user_data(
   1206         x, extract_tensors_from_dataset=True)
   1207     # If `self._distribution_strategy` is True, then we are in a replica context

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_v1.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2345       return [], [], None
   2346 
-> 2347     return self._standardize_tensors(
   2348         x, y, sample_weight,
   2349         run_eagerly=run_eagerly,

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_v1.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
   2373     if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
   2374       # TODO(fchollet): run static checks with dataset output shape(s).
-> 2375       x = training_utils_v1.standardize_input_data(
   2376           x,
   2377           feed_input_names,

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_utils_v1.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    661         for dim, ref_dim in zip(data_shape, shape):
    662           if ref_dim != dim and ref_dim is not None and dim is not None:
--> 663             raise ValueError('Error when checking ' + exception_prefix +
    664                              ': expected ' + names[i] + ' to have shape ' +
    665                              str(shape) + ' but got array with shape ' +

ValueError: Error when checking input: expected flatten_4_input to have shape (1, 16) but got array with shape (1, 4)
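
A way to narrow down where the (1, 4) array comes from might be to look at what reset() and step() hand back before keras-rl consumes them, since the agent feeds those return values straight into the network. Below is a minimal debugging sketch against the env object defined above; the idea that the tf-agents TimeStep returned by ts.restart() (a 4-field tuple) is the culprit is only a guess:

import numpy as np

obs = env.reset()
# keras-rl treats whatever reset() returns as the observation itself,
# so this should be the bare 16-cell board rather than a wrapper object.
print(type(obs))

obs, reward, done, info = env.step(0)
print(np.asarray(obs).shape)  # expected: (16,) for the flattened board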

0 Answers:

No answers yet.