
时间:2017-04-12 18:58:32

标签: python machine-learning tensorflow reinforcement-learning q-learning

我已经检查过 this question，并确认本问题不是重复问题。






     # Clip gradients to prevent gradient explosion
    gradients = self.optimizer.compute_gradients(self.loss)
    clipped_gradients = [(tf.clip_by_value(grad,-1.,1.), var) for grad, var in gradients]
    self.update_model = self.optimizer.apply_gradients(clipped_gradients)



 def td_update(self, current_state, last_action, next_state, reward):
    """Store one transition and, once warmed up, take one Q-learning step.

    The transition is written into the replay table at slot
    ``transition_count % replay_size`` (a circular buffer). After at least
    ``update_size * 20`` transitions have been recorded, a random mini-batch
    is drawn from the filled part of the table and one optimizer step is run
    toward the one-step Q-learning target ``r + gamma * max_a' Q(s', a')``.

    Note: the bootstrap uses the max over next-state Q-values (off-policy
    Q-learning), not the next action actually taken, so this is not SARSA.

    Args:
        current_state: features of the state the action was taken in.
            NOTE(review): presumably shaped (1, ...) so that ``np.vstack``
            stacks samples along a batch axis -- confirm against caller.
        last_action: integer index of the action taken; it is used as a
            column index into the Q-value matrix.
        next_state: features of the resulting state. A next-state that
            contains any NaN is treated as terminal (bootstrap value 0).
        reward: reward received for this transition.
    """
    # Write the transition into the circular replay buffer.
    self.replay_table[self.transition_count % self.replay_size] = (current_state, last_action, next_state, reward)
    self.transition_count += 1

    # Don't start learning until the replay table has some data.
    warmup = self.update_size * 20
    if self.transition_count < warmup:
        return
    if self.transition_count == warmup:
        print("Replay Table is Ready\n")

    # Draw a random mini-batch from the filled part of the table.
    # `random` is presumably numpy.random (stdlib random.choice has no `size`
    # argument), which samples with replacement by default.
    filled = self.replay_table[:min(self.transition_count, self.replay_size)]
    random_tbl = random.choice(filled, size=self.update_size)
    feature_vectors = np.vstack(random_tbl['state'])
    actions = random_tbl['action']
    next_feature_vectors = np.vstack(random_tbl['next_state'])
    rewards = random_tbl['reward']

    # A next-state containing any NaN marks a terminal transition. Reduce over
    # every non-batch axis so this works whatever the rank of the features.
    nan_axes = tuple(range(1, next_feature_vectors.ndim))
    non_terminal_ix = np.flatnonzero(~np.isnan(next_feature_vectors).any(axis=nan_axes))

    q_current = self.get_Q_values(feature_vectors)
    # Default q_next is all zeros, which is the correct bootstrap value for
    # terminal states. Skip the forward pass when every sample is terminal.
    q_next = np.zeros([self.update_size, len(self._environment.action_list)])
    if non_terminal_ix.size:
        q_next[non_terminal_ix] = self.get_Q_values(next_feature_vectors[non_terminal_ix])

    # Start from the current predictions so that target - q_current is zero for
    # every action that was NOT taken; the gradient then only flows through the
    # taken action of each sample.
    target = q_current.copy()

    # Replace (do not add to) the taken action's value with the TD target
    # r + gamma * max_a' Q(s', a'). With `+=` the target would become
    # Q(s, a) + r + gamma * max Q(s', .), so the TD error would ignore the
    # current estimate and the Q-values would grow without bound.
    target[np.arange(len(target)), actions] = rewards + self.gamma * q_next.max(axis=1)

    # Logging
    if self.log_file is not None:
        print ("Current Q Value: {}".format(q_current),file=self.log_file)
        print ("Next Q Value: {}".format(q_next),file=self.log_file)
        print ("Current Rewards: {}".format(rewards),file=self.log_file)
        print ("Actions: {}".format(actions),file=self.log_file)
        print ("Targets: {}".format(target),file=self.log_file)

        # Log some of the gradients to check for gradient explosion
        loss, output_grad, conv_grad = self.sess.run(
            [self.loss, self.output_gradient, self.convolutional_gradient],
            feed_dict={self.target_Q: target, self.input_matrix: feature_vectors})
        print ("Loss: {}".format(loss),file=self.log_file)
        print ("Output Weight Gradient: {}".format(output_grad),file=self.log_file)
        print ("Convolutional Gradient: {}".format(conv_grad),file=self.log_file)

    # Update the model
    self.sess.run(self.update_model, feed_dict={self.target_Q: target, self.input_matrix: feature_vectors})


0 个答案:
