I am writing a policy network to find a localisation policy as described in this paper, closely following the guidance of 1, 2 and 3. I am having trouble implementing the network's loss.
I am following [3] above in particular, but something about my implementation of the loss function seems to be wrong and I can't figure out what. When I train the network, the loss does not decrease and the network never learns any policy. Any guidance would be greatly appreciated.
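For reference, the loss I am trying to implement is (as I understand [3]) the standard REINFORCE policy-gradient objective: the negative log-probability of each chosen action, weighted by the discounted return that followed it, averaged over the batch:

L(\theta) = -\frac{1}{N} \sum_{t=1}^{N} \log \pi_\theta(a_t \mid s_t) \, G_t

where \pi_\theta(a_t \mid s_t) is the network's softmax output for the action taken at step t, and G_t is the discounted return.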
Policy Network
import tensorflow as tf
from keras import backend as K
from keras import layers, optimizers
from keras.models import Model

# MODE, INPUT_DIM_PN, ACTIONS, MODEL_DIR and learning_rate are
# module-level constants defined elsewhere.

def __init__(self):
    """ Policy Network for a localisation agent. """
    # Create the Policy Network graph
    self.graph = tf.Graph()
    with self.graph.as_default():
        # Create the Policy Network session
        self.session = tf.Session()
        with self.session.as_default():
            K.set_session(self.session)
            if MODE == "TRAIN":
                # Define the Policy Network's architecture
                self.input = layers.Input(shape=(INPUT_DIM_PN,))
                net = self.input
                net = layers.Dense(1024, kernel_initializer="he_normal")(net)
                net = layers.Activation("relu")(net)
                net = layers.Dropout(0.2)(net)
                net = layers.Dense(1024, kernel_initializer="he_normal")(net)
                net = layers.Activation("relu")(net)
                net = layers.Dropout(0.2)(net)
                # Final layer outputs one probability per action
                net = layers.Dense(ACTIONS, kernel_initializer="he_normal")(net)
                net = layers.Activation("softmax")(net)
                self.model = Model(inputs=self.input, outputs=net)
                # Print a summary of the network architecture
                self.model.summary(line_length=200)
            if MODE == "TEST":
                # Load the trained Policy Network from file
                self.model = tf.keras.models.load_model(MODEL_DIR)
                print("Policy Network successfully loaded.")
            # Build the train function
            self.build_train_fn()
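For context, this is roughly how actions are sampled from the network during a rollout (a simplified sketch: get_action is a hypothetical stand-in for my rollout code, state is a numpy feature vector of length INPUT_DIM_PN, and numpy is imported as np):

def get_action(self, state):
    """ Sample an action index from the network's softmax output. """
    with self.graph.as_default():
        K.set_session(self.session)
        # predict() expects a batch, so add a leading batch dimension
        action_probs = self.model.predict(state.reshape(1, INPUT_DIM_PN))[0]
    # Sample stochastically according to the predicted probabilities
    return np.random.choice(ACTIONS, p=action_probs)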
Train function
def build_train_fn(self):
    """ Train function for the Policy Network.

    Defines the loss and parameter updates used to train the
    Policy Network to maximise the expected reward of a
    localisation sequence. This function replaces model.fit(X, y).
    """
    with self.graph.as_default():
        K.set_session(self.session)
        # Action probabilities predicted by the network
        action_prob_placeholder = self.model.output
        # One-hot encoding of the action that was actually taken
        action_onehot_placeholder = K.placeholder(shape=(None, ACTIONS),
                                                  name="action_onehot")
        # Discounted return observed after taking that action
        discount_reward_placeholder = K.placeholder(shape=(None,),
                                                    name="discount_reward")
        # Probability the network assigned to the chosen action
        action_prob = K.sum(action_prob_placeholder * action_onehot_placeholder,
                            axis=1)
        log_action_prob = K.log(action_prob)
        # REINFORCE loss: negative log-probability weighted by the return
        loss = (-log_action_prob) * discount_reward_placeholder
        loss = K.mean(loss)
        rmsprop = optimizers.RMSprop(lr=learning_rate, decay=0.99)
        updates = rmsprop.get_updates(params=self.model.trainable_weights,
                                      loss=loss)
        self.train_fn = K.function(inputs=[self.model.input,
                                           action_onehot_placeholder,
                                           discount_reward_placeholder],
                                   outputs=[],
                                   updates=updates)
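And this is roughly how the train function gets called at the end of each episode (again a simplified sketch, not the code from [3]: discount_rewards, GAMMA, policy_net, states, actions and rewards are hypothetical names for illustration):

import numpy as np

def discount_rewards(rewards, gamma):
    """ Compute the discounted return G_t at every step of one episode. """
    returns = np.zeros_like(rewards, dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    # Normalising the returns keeps the gradient scale stable across episodes
    returns -= returns.mean()
    returns /= (returns.std() + 1e-8)
    return returns

# At the end of an episode:
#   states:  (T, INPUT_DIM_PN) array of network inputs
#   actions: (T,) array of action indices
#   rewards: (T,) array of per-step rewards
action_onehot = np.eye(ACTIONS)[actions]
discounted = discount_rewards(np.array(rewards, dtype=np.float32), GAMMA)
policy_net.train_fn([states, action_onehot, discounted])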