pytorch模型不更新权重

时间:2019-10-10 03:26:38

标签: python pytorch

我正在尝试用pytorch解决CartPole问题,但是经过多次迭代后,参数不会更新。

我要重现的代码是用喀拉拉邦制成的[[https://github.com/gsurma/cartpole/blob/master/cartpole.py]”。

import random
from collections import deque

import gym
import numpy as np
import torch.nn as nn
import torch

GAMMA = 0.95

MEMORY_SIZE = 1000000
BATCH_SIZE = 50

EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995


# Principal Neural Netword module
class DQNSolver(nn.Module):
    def __init__(self, observation_space, action_space):
        super(DQNSolver, self).__init__()
        self.action_space = action_space
        self.observation_space = observation_space

        self.hiddenSpace = 24

        self.fc1 = nn.Linear(self.observation_space, self.hiddenSpace)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(self.hiddenSpace, self.action_space)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


class cartpole():

    def __init__(self):

        # Device configuration
        device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')

        self.env = gym.make("CartPole-v1")
        self.observation_space = self.env.observation_space.shape[0]
        self.action_space = self.env.action_space.n

        self.dqn_solver = DQNSolver(self.observation_space, self.action_space).to(device)

        # Create the memory
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.exploration_rate = EXPLORATION_MAX

        # Create the optimizer and the loss
        self.optimizer = torch.optim.Adam(self.dqn_solver.parameters(), lr=0.1)
        self.loss_func = torch.nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    # This custom function will receive a observation of the state of the
    # environment, and then return the action
    def predict_action(self, observation):
        # if np.random.rand() < self.exploration_rate:
        #     randomQ = np.random.rand(1, self.action_space)[0]
        #     return np.double(randomQ)
        predicted = self.dqn_solver(observation)
        predicted = predicted.cpu().data.numpy()
        return np.double(predicted)

    def optimize_model(self, state, q_values):
        output = self.predict_action(state)
        output = torch.tensor(output, requires_grad=False)
        qValues = torch.tensor(q_values, requires_grad=True)

        self.optimizer.zero_grad()
        loss = self.loss_func(output,  qValues)
        loss.backward()
        self.optimizer.step()


        print("Loss: {}".format(loss))

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                next_action = self.predict_action(state_next)
                q_update = (reward + GAMMA * np.amax(next_action))
            q_values = self.predict_action(state)
            q_values[action] = q_update

            self.optimize_model(state, q_values)
        print("Finished replay")
        print('weights after backpropagation = ',   list(self.dqn_solver.parameters()))
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)

    def run(self):
        while True:
            state = self.env.reset()
            # state = np.reshape(state, [1, self.observation_space])
            state = torch.Tensor(state)
            while True:
                self.env.render()
                action = self.predict_action(state)
                action = np.argmax(action)
                state_next, reward, terminal, info = self.env.step(action)
                reward = reward if not terminal else -reward
                state_next = torch.Tensor(state_next)
                self.remember(state, action, reward, state_next, terminal)
                state = state_next
                if terminal:
                    break
                self.experience_replay()


if __name__ == "__main__":
    cartpole().run()

在几个时期之后,此行中打印的参数是相同的:

print('weights after backpropagation = ',   list(self.dqn_solver.parameters()))

损耗值也更接近随机值,但不会增加或减少。 怎么了?

1 个答案:

答案 0 :(得分:0)

问题

optimize_model中,将q_value替换为新的张量,该张量不是原始计算图的节点,因此无法将梯度传递回网络。参见下面的示例,

import torch
import torch.nn as nn
import torch.optim as optim

clf = nn.Linear(2, 2)
opt = optim.SGD(clf.parameters(), lr=0.1)
crit = nn.MSELoss()

input = torch.arange(2).float().view(-1, 2)
label = torch.arange(2).float().view(-1, 2)

pred = clf(input)
pred_copy = torch.tensor(pred, requires_grad=True)

opt.zero_grad()
loss_wrong = crit(pred_copy, label)
loss_wrong.backward()
for p in clf.parameters():
    print(p.grad)

opt.zero_grad()
loss_correct = crit(pred, label)
loss_correct.backward()
for p in clf.parameters():
    print(p.grad)

输出:

None
None
tensor([[-0.0000, -0.5813],
        [-0.0000, -0.9274]])
tensor([-0.5813, -0.9274])

从上面的快速示例中,您可以看到根据复制的预测计算出的梯度无法传递给网络参数。


解决方案

def optimize_model(self, state, q_values):
    output = self.predict_action(state)
    output.requires_grad = False
    qValues.requires_grad = True

此外,正如@JoshVarty在评论中提到的那样,您不应将计算图中的任何张量转换为numpy并将其转换回。那会破坏图形,因此渐变将无法正确传递(计算)。

tl; dr 仅尽可能使用Pytorch内置函数。