I am training a reinforcement learning model on OpenAI Gym's CartPole environment. Although the weights and the model's .h5 file appear in the target directory, running the following code - tf.train.get_checkpoint_state("C:/Users/dgt/Documents") - still shows nothing.
Here is my full code -
## Slightly modified from the following repository - https://github.com/gsurma/cartpole
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import random
import gym
import numpy as np
import tensorflow as tf
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
ENV_NAME = "CartPole-v1"
GAMMA = 0.95
LEARNING_RATE = 0.001
MEMORY_SIZE = 1000000
BATCH_SIZE = 20
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995
checkpoint_path = "training_1/cp.ckpt"
class DQNSolver:

    def __init__(self, observation_space, action_space):
        # save_dir = args.save_dir
        # self.save_dir = save_dir
        # if not os.path.exists(save_dir):
        #     os.makedirs(save_dir)
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        q_values = self.model.predict(state)
        return np.argmax(q_values[0])

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0]))
            q_values = self.model.predict(state)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)


def cartpole():
    env = gym.make(ENV_NAME)
    #score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)

    checkpoint = tf.train.get_checkpoint_state("C:/Users/dgt/Documents")
    print('checkpoint:', checkpoint)
    if checkpoint and checkpoint.model_checkpoint_path:
        dqn_solver.model = keras.models.load_model('cartpole.h5')
        dqn_solver.model = model.load_weights('cartpole_weights.h5')

    run = 0
    i = 0
    while i < 2:
        i = i + 1
        #total = 0
        run += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            #env.render()
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            #total += reward
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            dqn_solver.model.save('cartpole.h5')
            dqn_solver.model.save_weights('cartpole_weights.h5')
            if terminal:
                print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                #score_logger.add_score(step, run)
                break
            dqn_solver.experience_replay()


if __name__ == "__main__":
    cartpole()
Both cartpole_weights.h5 and cartpole.h5 appear in my target directory. However, I believe another file named "checkpoint" should also appear, and my understanding is that this is why my code is not working.
Answer (score: 1)
First, the code will not run if you have not yet saved the weights/model. So I commented out the following lines and ran the script once to generate the files for the first time.
checkpoint = tf.train.get_checkpoint_state(".")
print('checkpoint:', checkpoint)
if checkpoint and checkpoint.model_checkpoint_path:
    dqn_solver.model = tf.keras.models.load_model('cartpole.h5')
    dqn_solver.model.load_weights('cartpole_weights.h5')
Note that I also modified the code above - it had a few syntax errors. In particular, this line from your post
dqn_solver.model = model.load_weights('cartpole_weights.h5')
was probably causing the problem, because the model.load_weights('file') method mutates the model in place (it does not return the model).
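To illustrate the difference (a minimal sketch, assuming the cartpole_weights.h5 file produced by the script above already exists and the DQNSolver class is in scope):

model = DQNSolver(4, 2).model        # CartPole-v1: 4 observations, 2 actions

model.load_weights('cartpole_weights.h5')            # correct: mutates `model` in place
# model = model.load_weights('cartpole_weights.h5')  # wrong: rebinds `model` to the return
#                                                    # value, which is not a usable Keras model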
Then I tested whether the model weights were being saved/loaded correctly. To do that, you can run
dqn_solver = DQNSolver(observation_space, action_space)
dqn_solver.model.trainable_variables
to see the (randomly initialized) weights the model has when it is first created. Then you can load the weights with either
dqn_solver.model = tf.keras.models.load_model('cartpole.h5')
or
dqn_solver.model.load_weights('cartpole_weights.h5')
and then you can inspect trainable_variables again to make sure they have changed from the initial values and that both loading methods give the same weights.
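For example, a check along those lines could look like this (a sketch, not part of the original answer; it assumes the DQNSolver class, observation_space and action_space from the script above are in scope, and that the saved .h5 files exist):

dqn_solver = DQNSolver(observation_space, action_space)
before = [v.numpy().copy() for v in dqn_solver.model.trainable_variables]  # random init

dqn_solver.model.load_weights('cartpole_weights.h5')
after = [v.numpy() for v in dqn_solver.model.trainable_variables]

# The loaded weights should no longer match the random initialization.
print(any(not np.allclose(b, a) for b, a in zip(before, after)))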
When you save the model, it saves the complete architecture - the exact configuration of the layers. When you save the weights, it only saves the list of tensors you can see in trainable_variables. Note that when you use load_weights, you need to load them into the exact architecture the weights came from, otherwise it will not work properly. So if you change the model architecture in DQNSolver and then try to load_weights saved for the old model, it will not work correctly. If you use load_model, it resets the model to exactly the saved architecture and then sets the weights.
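Note, too, that neither model.save('cartpole.h5') nor save_weights('cartpole_weights.h5') writes a TensorFlow-format checkpoint, so there is no file literally named "checkpoint" for tf.train.get_checkpoint_state to find - which is presumably why that call showed nothing; checking for the .h5 file itself (for example with os.path.exists) is a simpler guard. A small sketch of the architecture-matching point (the 48-unit model below is hypothetical, only for contrast; it assumes the definitions and files from the script above):

matching = DQNSolver(observation_space, action_space).model
matching.load_weights('cartpole_weights.h5')        # works: layer shapes line up

mismatched = Sequential([
    Dense(48, input_shape=(observation_space,), activation="relu"),  # 48 units instead of 24
    Dense(48, activation="relu"),
    Dense(action_space, activation="linear"),
])
# mismatched.load_weights('cartpole_weights.h5')    # would fail: the saved tensors do not fit

full_model = tf.keras.models.load_model('cartpole.h5')  # rebuilds the exact saved architecture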
Edit - the entire modified script
## Slightly modified from the following repository - https://github.com/gsurma/cartpole
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import random
import gym
import numpy as np
import tensorflow as tf
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
ENV_NAME = "CartPole-v1"
GAMMA = 0.95
LEARNING_RATE = 0.001
MEMORY_SIZE = 1000000
BATCH_SIZE = 20
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995
checkpoint_path = "training_1/cp.ckpt"
class DQNSolver:

    def __init__(self, observation_space, action_space):
        # save_dir = args.save_dir
        # self.save_dir = save_dir
        # if not os.path.exists(save_dir):
        #     os.makedirs(save_dir)
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        q_values = self.model.predict(state)
        return np.argmax(q_values[0])

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0]))
            q_values = self.model.predict(state)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)


def cartpole():
    env = gym.make(ENV_NAME)
    #score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)

    # checkpoint = tf.train.get_checkpoint_state(".")
    # print('checkpoint:', checkpoint)
    # if checkpoint and checkpoint.model_checkpoint_path:
    #     dqn_solver.model = tf.keras.models.load_model('cartpole.h5')
    #     dqn_solver.model.load_weights('cartpole_weights.h5')

    run = 0
    i = 0
    while i < 2:
        i = i + 1
        #total = 0
        run += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            #env.render()
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            #total += reward
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            dqn_solver.model.save('cartpole.h5')
            dqn_solver.model.save_weights('cartpole_weights.h5')
            if terminal:
                print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                #score_logger.add_score(step, run)
                break
            dqn_solver.experience_replay()


if __name__ == "__main__":
    cartpole()

#%% to load saved results
env = gym.make(ENV_NAME)
#score_logger = ScoreLogger(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn_solver = DQNSolver(observation_space, action_space)

dqn_solver.model = tf.keras.models.load_model('cartpole.h5')  # or
dqn_solver.model.load_weights('cartpole_weights.h5')
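Once the model has been restored, a quick sanity check (a sketch, not part of the original script) is to run a single episode with purely greedy actions and see what score it reaches:

# Assumes env, observation_space and dqn_solver from the cell above.
state = np.reshape(env.reset(), [1, observation_space])
score = 0
while True:
    action = np.argmax(dqn_solver.model.predict(state)[0])  # always exploit, no exploration
    state, reward, terminal, info = env.step(action)
    state = np.reshape(state, [1, observation_space])
    score += reward
    if terminal:
        break
print("score with loaded model:", score)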