我正在使用 Chainer 框架学习深度强化学习。
我按照一个教程编写了以下代码:
def train_dddqn(env):
    """Train a Dueling Double DQN (DDDQN) agent on ``env``.

    Uses epsilon-greedy exploration, a fixed-size replay memory and a target
    network (``Q_ast``) that is refreshed from the online network ``Q`` every
    ``update_q_freq`` steps once the memory is full.

    Args:
        env: Environment exposing ``history_t``, ``data``, ``reset()`` (returns
            the first observation) and ``step(action)`` (returns
            ``(obs, reward, done)``). Each observation must flatten to
            ``env.history_t + 1`` values. The agent chooses among 3 actions.

    Returns:
        ``(Q, total_losses, total_rewards)``: the trained online network and,
        per epoch, the summed training loss and the summed reward.
    """
    from collections import deque

    class Q_Network(chainer.Chain):
        """Dueling Q-network: shared trunk, separate value and advantage heads."""

        def __init__(self, input_size, hidden_size, output_size):
            super(Q_Network, self).__init__()
            # init_scope() is Chainer's documented way to register child links
            # (passing links as kwargs to Chain.__init__ is deprecated).
            with self.init_scope():
                self.fc1 = L.Linear(input_size, hidden_size)
                self.fc2 = L.Linear(hidden_size, hidden_size)
                self.fc3 = L.Linear(hidden_size, hidden_size // 2)  # value stream
                self.fc4 = L.Linear(hidden_size, hidden_size // 2)  # advantage stream
                self.state_value = L.Linear(hidden_size // 2, 1)
                self.advantage_value = L.Linear(hidden_size // 2, output_size)
            # Kept so an identical network can be rebuilt, e.g. before load_npz.
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.output_size = output_size

        def __call__(self, x):
            h = F.relu(self.fc1(x))
            h = F.relu(self.fc2(h))
            hs = F.relu(self.fc3(h))
            ha = F.relu(self.fc4(h))
            state_value = self.state_value(hs)          # (batch, 1)
            advantage_value = self.advantage_value(ha)  # (batch, output_size)
            advantage_mean = (F.sum(advantage_value, axis=1) / float(self.output_size)).reshape(-1, 1)
            # Dueling aggregation Q = V + (A - mean(A)). Chainer arithmetic does
            # not broadcast implicitly, so expand V and mean(A) explicitly.
            shape = advantage_value.shape
            return (F.broadcast_to(state_value, shape) + advantage_value
                    - F.broadcast_to(advantage_mean, shape))

        def reset(self):
            """Clear accumulated gradients before a backward pass."""
            self.cleargrads()

    n_actions = 3
    Q = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=n_actions)
    Q_ast = copy.deepcopy(Q)  # target network
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(Q)

    epoch_num = 50
    step_max = len(env.data) - 1
    memory_size = 200
    batch_size = 50
    epsilon = 1.0
    epsilon_decrease = 1e-3
    epsilon_min = 0.1
    start_reduce_epsilon = 200
    train_freq = 10
    update_q_freq = 20
    gamma = 0.97
    show_log_freq = 5

    # deque(maxlen=...) discards the oldest transition in O(1) once full.
    memory = deque(maxlen=memory_size)
    total_step = 0
    total_rewards = []
    total_losses = []
    start = time.time()
    for epoch in range(epoch_num):
        pobs = env.reset()
        step = 0
        done = False
        total_reward = 0
        total_loss = 0
        while not done and step < step_max:
            # Select an action: greedy w.r.t. Q with probability 1 - epsilon.
            if np.random.rand() > epsilon:
                q_now = Q(np.array(pobs, dtype=np.float32).reshape(1, -1))
                pact = int(np.argmax(q_now.data))
            else:
                pact = np.random.randint(n_actions)

            obs, reward, done = env.step(pact)
            memory.append((pobs, pact, reward, obs, done))

            # Train, and refresh the target network, only once memory is full.
            if len(memory) == memory_size:
                if total_step % train_freq == 0:
                    # Shuffle indices rather than the transitions themselves:
                    # np.random.permutation(memory) builds a ragged object
                    # array, which NumPy >= 1.24 rejects.
                    order = np.random.permutation(len(memory))
                    for lo in range(0, len(order), batch_size):
                        batch = [memory[k] for k in order[lo:lo + batch_size]]
                        n = len(batch)  # the final batch may be shorter
                        pobs_b, pact_b, reward_b, obs_b, done_b = zip(*batch)
                        b_pobs = np.array(pobs_b, dtype=np.float32).reshape(n, -1)
                        b_pact = np.array(pact_b, dtype=np.int32)
                        b_reward = np.array(reward_b, dtype=np.float32)
                        b_obs = np.array(obs_b, dtype=np.float32).reshape(n, -1)
                        b_done = np.array(done_b, dtype=bool)

                        q = Q(b_pobs)
                        # Double DQN: the online network picks the next action
                        # on the NEXT states, the target network scores it.
                        next_act = np.argmax(Q(b_obs).data, axis=1)
                        next_q = Q_ast(b_obs).data
                        rows = np.arange(n)
                        target = q.data.copy()
                        # Only the taken action's entry gets a new target; the
                        # others equal q, so they contribute zero loss.
                        target[rows, b_pact] = (
                            b_reward + gamma * next_q[rows, next_act] * ~b_done)

                        Q.reset()
                        loss = F.mean_squared_error(q, target)
                        total_loss += float(loss.data)
                        loss.backward()
                        optimizer.update()
                if total_step % update_q_freq == 0:
                    Q_ast = copy.deepcopy(Q)

            # Linear epsilon decay after a warm-up of start_reduce_epsilon steps.
            if epsilon > epsilon_min and total_step > start_reduce_epsilon:
                epsilon -= epsilon_decrease

            total_reward += reward
            pobs = obs
            step += 1
            total_step += 1

        total_rewards.append(total_reward)
        total_losses.append(total_loss)
        if (epoch + 1) % show_log_freq == 0:
            # Averages over the last show_log_freq epochs.
            log_reward = sum(total_rewards[-show_log_freq:]) / show_log_freq
            log_loss = sum(total_losses[-show_log_freq:]) / show_log_freq
            elapsed_time = time.time() - start
            print('\t'.join(map(str, [epoch + 1, epsilon, total_step, log_reward, log_loss, elapsed_time])))
            start = time.time()

    return Q, total_losses, total_rewards
# Run training. `train` and `Environment1` are defined elsewhere; `train` is
# presumably the training data — confirm against the surrounding notebook.
Q, total_losses, total_rewards = train_dddqn(env=Environment1(train))
我的问题是:如何保存和加载训练好的模型?我知道 Keras 提供了 model.save 和 load_model 这样的函数。
那么,对于这段 Chainer 代码,需要编写什么代码来实现模型的保存和加载?
答案 0 :(得分:1)
您可以使用 `serializers` 模块来保存/加载 Chainer 模型(`Chain` 类)的参数。注意:`save_npz` 只保存参数,不保存网络结构,因此在加载之前需要先构造一个结构完全相同的网络,再用 `load_npz` 把参数读入其中。
from chainer import serializers

# Q is the trained network returned by train_dddqn(...) above.
# Save its parameters to an .npz archive at the given path.
serializers.save_npz('my.model', Q)

# save_npz stores parameters only, not the architecture, so first build a
# network with the same layer sizes and then load the parameters into it.
# Q_Network is defined inside train_dddqn and is not visible at module level;
# type(Q) reaches the same class, and Q keeps its constructor sizes as attributes.
Q_ast = type(Q)(input_size=Q.input_size, hidden_size=Q.hidden_size,
                output_size=Q.output_size)
serializers.load_npz('my.model', Q_ast)
有关详细信息,请参见官方文档。
此外,您还可以参考 ChainerRL,这是一个基于 Chainer 的强化学习库。ChainerRL 提供了实用函数 `copy_param`,可将参数从网络 `source_link` 复制到 `target_link`。