I am trying to train my neural network across multiple processes with torch.multiprocessing, and my code hangs when it tries to evaluate the network on a specific tensor. I have isolated the problem to the code below (you can skip the first part, up to class DQN; it is just some code that formats the data).
Code structure:
- a few classes that format the data: FireResetEnv, MaxAndSkipEnv, ProcessFrame84, BufferWrapper, ImageToPyTorch, ScaledFloatFrame
- class DQN: implements the neural network
- function play: evaluates the network on a specific tensor
- main: launches the processes
import torch
import torch.nn as nn
import gym
import gym.spaces
import numpy as np
import torch.multiprocessing as mp
import collections
import cv2
#DO NOT READ UNTIL "CLASS DQN"
class FireResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        super(FireResetEnv, self).__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action):
        return self.env.step(action)

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        """Return only every 'skip'-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self):
        # renamed from _reset: recent gym versions call reset(), so _reset
        # was dead code and the frame buffer was never cleared between episodes
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs
class ProcessFrame84(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            assert False, "Unknown resolution."
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)
class BufferWrapper(gym.ObservationWrapper):
    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer
class ImageToPyTorch(gym.ObservationWrapper):
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)

class ScaledFloatFrame(gym.ObservationWrapper):
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0
def make_env(env_name):
    env = gym.make(env_name)
    env = MaxAndSkipEnv(env)
    env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, 4)
    return ScaledFloatFrame(env)

env = make_env('PongNoFrameskip-v4')
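# Note that env is created at import time: with the default 'fork' start
# method on Linux the workers inherit this exact object, while with 'spawn'
# each worker would rebuild it when the module is re-imported.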
#START READING HERE#
class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x)
        conv_out = conv_out.view(x.size()[0], -1)
        return self.fc(conv_out)
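# Side note on the sizing trick above: _get_conv_out pushes a dummy batch
# through the conv stack at construction time to size the first Linear layer.
# For the (4, 84, 84) observations produced by make_env this comes out to
# 64 * 7 * 7 = 3136 features.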
def play(Q, rank):
    obs = env.reset()
    print(1)
    print(Q(torch.FloatTensor([obs])))  # this is the call that never returns
    print(2)
if __name__ == "__main__":
    num_processes = 4
    device = torch.device("cpu")
    Q = DQN(env.observation_space.shape, env.action_space.n)
    QHat = DQN(env.observation_space.shape, env.action_space.n)
    Q.share_memory()
    QHat.share_memory()
    frame_id = mp.Value('i', 0)
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=play, args=(Q, rank))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
When I run this code it prints "1" four times, as expected, and then stops (it should also print the evaluation Q(obs) and "2" four times).
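To help isolate the problem, below is a stripped-down sketch of the same pattern (a shared-memory model whose forward pass runs inside mp.Process workers) with gym removed entirely; TinyNet and its shapes are invented for illustration, not part of my real code. If this sketch hangs in the same way, the environment wrappers can be ruled out as the cause.
import torch
import torch.nn as nn
import torch.multiprocessing as mp

class TinyNet(nn.Module):
    # hypothetical stand-in for DQN: a single linear layer
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

def play(Q, rank):
    obs = torch.zeros(1, 4)  # dummy observation in place of env.reset()
    print(1)
    print(Q(obs))  # if the hang is in the forward pass itself, it should reproduce here
    print(2)

if __name__ == "__main__":
    Q = TinyNet()
    Q.share_memory()
    processes = []
    for rank in range(4):
        p = mp.Process(target=play, args=(Q, rank))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()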