我正在乒乓球馆环境中训练 DQN,以复制原始的 DQN“人类级别控制...”论文。我的算法运行良好并在较小的测试环境中收敛,但是在 pong 上进行测试时,每次迭代的训练速度都会大大减慢(开始时为 1 秒/100 帧,10 万次迭代后为 10 秒/100 帧)。这是我的模型和训练函数:
型号:
def create_model(self):
"""
Creates Q approximation model
:param state: state placeholder
:return: Sequential model used to predict Q values
"""
# Action Space
num_actions = self.env.action_space.n
# Functional API
state_shape = list(self.env.observation_space.shape)
inputs = tf.keras.Input(dtype=tf.uint8,
shape=(state_shape[0], state_shape[1], state_shape[2] * self.config.state_history))
x = tf.cast(inputs, tf.float32) / self.config.high
x = tf.keras.layers.Conv2D(32, (8, 8), strides=(4, 4), activation='relu')(x)
x = tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), activation='relu')(x)
x = tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), activation='relu')(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(512, activation='linear')(x)
outputs = tf.keras.layers.Dense(num_actions, activation='linear')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
训练函数:
@tf.function
def update_weights(self, states, actions, rewards, next_states, done_masks):
# Target Q value (Labels)
q_samp = tf.where(done_masks, rewards,
rewards + self.config.gamma * tf.reduce_max(self.target_model(next_states, training=True),
axis=1))
actions_one_hot = tf.one_hot(actions, self.num_actions)
with tf.GradientTape() as tape:
# Calculate Model Output (Predictions)
logits = self.model_1(states, training=True)
q = tf.reduce_sum(tf.multiply(logits, actions_one_hot), axis=1)
# Mean Squared Error Loss
loss = tf.reduce_mean(tf.math.squared_difference(q_samp, q))
# Calculate Gradients and clip by norm value
grads = tape.gradient(loss, self.model_1.trainable_weights)
grads = [tf.clip_by_norm(grad, self.config.clip_val) for grad in grads]
# Apply gradients to model
self.opt.apply_gradients(zip(grads, self.model_1.trainable_weights))
return loss, grads
一些附加说明
这不仅仅是训练步骤。即使使用 env.step(action)
更新环境,最后的时间也比开始的时间长 5 倍。
update_weights()
的所有输入都是 tf 张量,除了 self。除了 self.target_model
之外,其他 self 属性都不会改变,它的权重会定期更新。我已经测试过,没有发现回溯。
当数据集应该只接近 5 GB 时,它也倾向于使用大量内存(~25+ GB)。知道我做错了什么吗?
注释掉 tf.function
可以防止训练从一开始每一步减慢约 50%。