Writing a loss function to update a policy network

Date: 2018-11-29 06:21:26

Tags: python tensorflow machine-learning keras

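I am writing a policy network to find a localisation policy as described in this paper, closely following the guidance in [1], [2] and [3]. I am having trouble implementing the network's loss.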

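In particular I am following [3] above, but something seems to be wrong with my implementation of the loss function and I cannot work out what. When I train the network, the loss is not minimised and the network does not learn any policy. Any guidance would be greatly appreciated.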
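
For reference, the loss that the training function is meant to compute is the standard REINFORCE objective: the mean over time steps of -log π(a_t | s_t) · G_t, where G_t is the discounted return from step t. The pure-Python sketch below mirrors that computation for one episode so the numbers can be checked by hand. The helper names (`discount_rewards`, `reinforce_loss`), the discount factor `gamma` and the clipping constant `eps` are illustrative assumptions, not taken from the post or from [3].

```python
import math

def discount_rewards(rewards, gamma=0.99):
    """Hypothetical helper: G_t = r_t + gamma * G_{t+1} for each step of one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

def reinforce_loss(action_probs, actions_onehot, returns, eps=1e-8):
    """Hypothetical helper: mean of -log(pi(a_t|s_t)) * G_t, like the Keras graph in build_train_fn."""
    losses = []
    for probs, onehot, g in zip(action_probs, actions_onehot, returns):
        # Probability of the action actually taken (sum of softmax output * one-hot mask).
        p = sum(p_i * a_i for p_i, a_i in zip(probs, onehot))
        # Assumed safeguard, absent from the original graph: keep log() finite.
        p = max(p, eps)
        losses.append(-math.log(p) * g)
    return sum(losses) / len(losses)

# Two-step episode with 3 actions; only the final step is rewarded.
probs = [[0.2, 0.5, 0.3], [0.1, 0.1, 0.8]]
onehot = [[0, 1, 0], [0, 0, 1]]
returns = discount_rewards([0.0, 1.0], gamma=0.9)
print(returns)                                   # [0.9, 1.0]
print(round(reinforce_loss(probs, onehot, returns), 4))  # 0.4235
```

With policy gradients this surrogate loss is evaluated on freshly sampled episodes at every update, and its sign and scale follow the returns, so its value is not expected to fall steadily the way a supervised loss does; the average episode reward is usually a more reliable sign of learning.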

Policy Network

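# Imports assumed by this snippet (not shown in the original post):
import tensorflow as tf
from keras import backend as K
from keras import layers, optimizers
from keras.models import Model
# MODE, INPUT_DIM_PN, ACTIONS, MODEL_DIR and learning_rate are assumed to be defined elsewhere.

def __init__(self):
    """ Policy Network for a localisation agent. """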

    # Create the Policy Network graph
    self.graph = tf.Graph()
    with self.graph.as_default():

        # Create the Policy Network session
        self.session = tf.Session()
        with self.session.as_default():

            K.set_session(self.session)

            if MODE == "TRAIN":

                """ Define the Policy Network's architecture. """
                self.input = layers.Input(shape=(INPUT_DIM_PN, ))
                net = self.input

                net = layers.Dense(1024, kernel_initializer="he_normal")(net)
                net = layers.Activation("relu")(net)
                net = layers.Dropout(0.2)(net)

                net = layers.Dense(1024, kernel_initializer="he_normal")(net)
                net = layers.Activation("relu")(net)
                net = layers.Dropout(0.2)(net)

                net = layers.Dense(ACTIONS, kernel_initializer="he_normal")(net)
                net = layers.Activation("softmax")(net)

                self.model = Model(inputs=self.input, outputs=net)

                # Print a summary of the network architecture
                self.model.summary(line_length=200)

            if MODE == "TEST":

                """ Load the trained Policy Network from file. """
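                # self.model = models.load_model(MODEL_DIR)
                # Note: this returns a tf.keras model, whereas the TRAIN branch
                # builds the model with standalone Keras `layers` and `Model`.
                self.model = tf.keras.models.load_model(MODEL_DIR)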
                print("Policy Network successfully loaded.")

            # Build the train function
            self.build_train_fn()
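
The network's softmax output is a probability distribution over the `ACTIONS` discrete actions. In a REINFORCE setup the agent usually samples an action from it and records that action as a one-hot vector, which is the layout `action_onehot_placeholder` in the training function expects. A minimal pure-Python sketch, assuming the probabilities are one row of something like `self.model.predict(state)` (where `state` is a hypothetical batch of shape `(1, INPUT_DIM_PN)`); `sample_action` and `to_onehot` are hypothetical helpers.

```python
import random

def sample_action(probs, rng=random):
    """Hypothetical helper: draw index i with probability probs[i]."""
    # random.choices does weighted sampling; the weights need not sum exactly to 1.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

def to_onehot(action, num_actions):
    """Hypothetical helper: one-hot row in the layout fed to action_onehot_placeholder."""
    onehot = [0.0] * num_actions
    onehot[action] = 1.0
    return onehot

rng = random.Random(0)            # seeded so the run is reproducible
probs = [0.1, 0.7, 0.2]           # e.g. one row of the softmax output
action = sample_action(probs, rng)
print(action, to_onehot(action, len(probs)))
```

Sampling rather than taking the `argmax` matters during training: the REINFORCE gradient is an expectation over actions drawn from the current policy, so the recorded actions should come from that policy.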

Training function

def build_train_fn(self):
    """ Train function for the Policy Network.

    Defines metrics in order to train the Policy Network
    to maximise the expected reward of a localisation sequence.
    This function replaces model.fit(X, y).
    """

    with self.graph.as_default():

        K.set_session(self.session)

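        # Softmax output of the network, shape (batch, ACTIONS)
        action_prob_placeholder = self.model.output
        # One-hot encoding of the action actually taken at each step
        action_onehot_placeholder = K.placeholder(shape=(None, ACTIONS), name="action_onehot")
        # Discounted return G_t for each step
        discount_reward_placeholder = K.placeholder(shape=(None,), name="discount_reward")

        # Probability assigned to the chosen action; K.log gives -inf if it underflows to 0
        action_prob = K.sum(action_prob_placeholder * action_onehot_placeholder, axis=1)
        log_action_prob = K.log(action_prob)

        # REINFORCE loss: batch mean of -log(pi(a_t | s_t)) * G_t
        loss = (-log_action_prob) * discount_reward_placeholder
        loss = K.mean(loss)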

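        # In Keras 2.x, `decay` is time-based learning-rate decay, lr / (1 + decay * iterations);
        # RMSProp's moving-average coefficient is the separate `rho` argument (default 0.9).
        RMSprop = optimizers.RMSprop(lr=learning_rate, decay=0.99)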

        updates = RMSprop.get_updates(params=self.model.trainable_weights,
                                      loss=loss)

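        # Each call performs one RMSprop update:
        # self.train_fn([states, actions_onehot, discounted_rewards])
        self.train_fn = K.function(inputs=[self.model.input,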
                                           action_onehot_placeholder,
                                           discount_reward_placeholder],
                                   outputs=[],
                                   updates=updates)
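
One detail worth checking in `build_train_fn`: in Keras 2.x the `decay` argument of `optimizers.RMSprop` is not RMSProp's moving-average coefficient (that is `rho`) but a time-based learning-rate decay, applied as `lr / (1 + decay * iterations)`. If `decay=0.99` was meant as RMSProp's decay rate, the effective step size shrinks by a factor of about 100 within the first 100 updates. The sketch below only evaluates that schedule in plain Python; `keras_time_decay_lr` is a hypothetical helper and `learning_rate = 1e-3` is an assumed value, since the post does not show it.

```python
def keras_time_decay_lr(lr, decay, iterations):
    """Hypothetical helper: effective learning rate Keras 2.x optimizers use when decay > 0."""
    return lr * (1.0 / (1.0 + decay * iterations))

learning_rate = 1e-3  # assumed; the post does not show its value
for step in (0, 10, 100, 1000):
    print(step, keras_time_decay_lr(learning_rate, 0.99, step))
```

If RMSProp's decay rate was the intent, the Keras 2.x spelling would be `optimizers.RMSprop(lr=learning_rate, rho=0.99)`, leaving `decay` at its default of 0.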

0 answers